diff --git a/examples/simple_plugin/go.mod b/examples/simple_plugin/go.mod index 78d0a19b9a..95d66b7f03 100644 --- a/examples/simple_plugin/go.mod +++ b/examples/simple_plugin/go.mod @@ -21,6 +21,7 @@ require ( github.com/aws/aws-sdk-go-v2/internal/ini v1.8.1 // indirect github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect + github.com/aws/aws-sdk-go-v2/service/licensemanager v1.27.4 // indirect github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.23.4 // indirect github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 // indirect github.com/aws/aws-sdk-go-v2/service/ssooidc v1.26.5 // indirect diff --git a/examples/simple_plugin/go.sum b/examples/simple_plugin/go.sum index 143c48e399..2a372a40ec 100644 --- a/examples/simple_plugin/go.sum +++ b/examples/simple_plugin/go.sum @@ -25,6 +25,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 h1:KypMCbL github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4/go.mod h1:Vz1JQXliGcQktFTN/LN6uGppAIRoLBR2bMvIMP0gOjc= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 h1:tJ5RnkHCiSH0jyd6gROjlJtNwov0eGYNz8s8nFcR0jQ= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18/go.mod h1:++NHzT+nAF7ZPrHPsA+ENvsXkOO8wEu+C6RXltAG4/c= +github.com/aws/aws-sdk-go-v2/service/licensemanager v1.27.4 h1:8tRjT7S8LxBRNRP3KtdV9vj9dJPzG1yDvRIqVmznZII= +github.com/aws/aws-sdk-go-v2/service/licensemanager v1.27.4/go.mod h1:AhruhNzkEGM6NxQzGhc0gWvaj/o8FZi/cCoGymOVxyo= github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.23.4 h1:I9yxA99P3rbkzhv8iDykQcel7n03PmlK8GO6NDpOkj0= github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.23.4/go.mod h1:YAiuhtKyLLPdouuDXeFWh4nrDrMqwQqukNvDSyhltbU= github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 h1:zCsFCKvbj25i7p1u94imVoO447I/sFv8qq+lGJhRN0c= diff --git a/go.mod b/go.mod index 1b938cc9af..15202f7167 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require ( github.com/apache/arrow/go/v17 v17.0.0 github.com/aws/aws-sdk-go-v2 v1.30.4 github.com/aws/aws-sdk-go-v2/config v1.27.31 + github.com/aws/aws-sdk-go-v2/service/licensemanager v1.27.4 github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.23.4 github.com/bradleyjkemp/cupaloy/v2 v2.8.0 github.com/cloudquery/cloudquery-api-go v1.13.0 diff --git a/go.sum b/go.sum index 5ebe9d0c7a..6a88150c40 100644 --- a/go.sum +++ b/go.sum @@ -25,6 +25,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4 h1:KypMCbL github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.11.4/go.mod h1:Vz1JQXliGcQktFTN/LN6uGppAIRoLBR2bMvIMP0gOjc= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 h1:tJ5RnkHCiSH0jyd6gROjlJtNwov0eGYNz8s8nFcR0jQ= github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18/go.mod h1:++NHzT+nAF7ZPrHPsA+ENvsXkOO8wEu+C6RXltAG4/c= +github.com/aws/aws-sdk-go-v2/service/licensemanager v1.27.4 h1:8tRjT7S8LxBRNRP3KtdV9vj9dJPzG1yDvRIqVmznZII= +github.com/aws/aws-sdk-go-v2/service/licensemanager v1.27.4/go.mod h1:AhruhNzkEGM6NxQzGhc0gWvaj/o8FZi/cCoGymOVxyo= github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.23.4 h1:I9yxA99P3rbkzhv8iDykQcel7n03PmlK8GO6NDpOkj0= github.com/aws/aws-sdk-go-v2/service/marketplacemetering v1.23.4/go.mod h1:YAiuhtKyLLPdouuDXeFWh4nrDrMqwQqukNvDSyhltbU= github.com/aws/aws-sdk-go-v2/service/sso v1.22.5 h1:zCsFCKvbj25i7p1u94imVoO447I/sFv8qq+lGJhRN0c= diff --git a/premium/mocks/licensemanager.go b/premium/mocks/licensemanager.go new file mode 100644 index 0000000000..818ad27cb8 --- /dev/null +++ b/premium/mocks/licensemanager.go @@ -0,0 +1,56 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: offline.go + +// Package mocks is a generated GoMock package. +package mocks + +import ( + context "context" + reflect "reflect" + + licensemanager "github.com/aws/aws-sdk-go-v2/service/licensemanager" + gomock "github.com/golang/mock/gomock" +) + +// MockAWSLicenseManagerInterface is a mock of AWSLicenseManagerInterface interface. +type MockAWSLicenseManagerInterface struct { + ctrl *gomock.Controller + recorder *MockAWSLicenseManagerInterfaceMockRecorder +} + +// MockAWSLicenseManagerInterfaceMockRecorder is the mock recorder for MockAWSLicenseManagerInterface. +type MockAWSLicenseManagerInterfaceMockRecorder struct { + mock *MockAWSLicenseManagerInterface +} + +// NewMockAWSLicenseManagerInterface creates a new mock instance. +func NewMockAWSLicenseManagerInterface(ctrl *gomock.Controller) *MockAWSLicenseManagerInterface { + mock := &MockAWSLicenseManagerInterface{ctrl: ctrl} + mock.recorder = &MockAWSLicenseManagerInterfaceMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockAWSLicenseManagerInterface) EXPECT() *MockAWSLicenseManagerInterfaceMockRecorder { + return m.recorder +} + +// CheckoutLicense mocks base method. +func (m *MockAWSLicenseManagerInterface) CheckoutLicense(ctx context.Context, params *licensemanager.CheckoutLicenseInput, optFns ...func(*licensemanager.Options)) (*licensemanager.CheckoutLicenseOutput, error) { + m.ctrl.T.Helper() + varargs := []interface{}{ctx, params} + for _, a := range optFns { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CheckoutLicense", varargs...) + ret0, _ := ret[0].(*licensemanager.CheckoutLicenseOutput) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// CheckoutLicense indicates an expected call of CheckoutLicense. +func (mr *MockAWSLicenseManagerInterfaceMockRecorder) CheckoutLicense(ctx, params interface{}, optFns ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{ctx, params}, optFns...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CheckoutLicense", reflect.TypeOf((*MockAWSLicenseManagerInterface)(nil).CheckoutLicense), varargs...) +} diff --git a/premium/offline.go b/premium/offline.go index 5d3ea56b44..b9034be828 100644 --- a/premium/offline.go +++ b/premium/offline.go @@ -1,18 +1,25 @@ package premium import ( + "context" "crypto/ed25519" _ "embed" "encoding/hex" "encoding/json" "errors" + "fmt" "os" "path/filepath" "slices" "strings" "time" + "github.com/aws/aws-sdk-go-v2/aws" + awsConfig "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/licensemanager" + "github.com/aws/aws-sdk-go-v2/service/licensemanager/types" "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/google/uuid" "github.com/rs/zerolog" ) @@ -41,20 +48,88 @@ var publicKey string var timeFunc = time.Now -func ValidateLicense(logger zerolog.Logger, meta plugin.Meta, licenseFileOrDirectory string) error { - fi, err := os.Stat(licenseFileOrDirectory) +const awsProductSKU = "prod-2trmtbe74klkg" + +//go:generate mockgen -package=mocks -destination=../premium/mocks/licensemanager.go -source=offline.go AWSLicenseManagerInterface +type AWSLicenseManagerInterface interface { + CheckoutLicense(ctx context.Context, params *licensemanager.CheckoutLicenseInput, optFns ...func(*licensemanager.Options)) (*licensemanager.CheckoutLicenseOutput, error) +} + +type CQLicenseClient struct { + logger zerolog.Logger + meta plugin.Meta + licenseFileOrDirectory string + awsLicenseManagerClient AWSLicenseManagerInterface + isMarketplaceLicense bool +} + +type LicenseClientOptions func(updater *CQLicenseClient) + +func WithMeta(meta plugin.Meta) LicenseClientOptions { + return func(cl *CQLicenseClient) { + cl.meta = meta + } +} + +func WithLicenseFileOrDirectory(licenseFileOrDirectory string) LicenseClientOptions { + return func(cl *CQLicenseClient) { + cl.licenseFileOrDirectory = licenseFileOrDirectory + } +} + +func WithAWSLicenseManagerClient(awsLicenseManagerClient AWSLicenseManagerInterface) LicenseClientOptions { + return func(cl *CQLicenseClient) { + cl.awsLicenseManagerClient = awsLicenseManagerClient + } +} + +func NewLicenseClient(ctx context.Context, logger zerolog.Logger, ops ...LicenseClientOptions) (CQLicenseClient, error) { + cl := CQLicenseClient{ + logger: logger, + isMarketplaceLicense: os.Getenv("CQ_AWS_MARKETPLACE_LICENSE") == "true", + } + + for _, op := range ops { + op(&cl) + } + + if cl.isMarketplaceLicense && cl.awsLicenseManagerClient == nil { + cfg, err := awsConfig.LoadDefaultConfig(ctx) + if err != nil { + return cl, fmt.Errorf("failed to load AWS config: %w", err) + } + cl.awsLicenseManagerClient = licensemanager.NewFromConfig(cfg) + } + + return cl, nil +} + +func (lc CQLicenseClient) ValidateLicense(ctx context.Context) error { + // License can be provided via environment variable for AWS Marketplace or CLI flag + switch { + case lc.isMarketplaceLicense: + return lc.validateMarketplaceLicense(ctx) + case lc.licenseFileOrDirectory != "": + return lc.validateCQLicense() + default: + return ErrLicenseNotApplicable + } +} + +func (lc CQLicenseClient) validateCQLicense() error { + fi, err := os.Stat(lc.licenseFileOrDirectory) if err != nil { return err } if !fi.IsDir() { - return validateLicenseFile(logger, meta, licenseFileOrDirectory) + return lc.validateLicenseFile(lc.licenseFileOrDirectory) } found := false var lastError error - err = filepath.WalkDir(licenseFileOrDirectory, func(path string, d os.DirEntry, err error) error { + err = filepath.WalkDir(lc.licenseFileOrDirectory, func(path string, d os.DirEntry, err error) error { if d.IsDir() { - if path == licenseFileOrDirectory { + if path == lc.licenseFileOrDirectory { return nil } return filepath.SkipDir @@ -67,8 +142,8 @@ func ValidateLicense(logger zerolog.Logger, meta plugin.Meta, licenseFileOrDirec return nil } - logger.Debug().Str("path", path).Msg("considering license file") - lastError = validateLicenseFile(logger, meta, path) + lc.logger.Debug().Str("path", path).Msg("considering license file") + lastError = lc.validateLicenseFile(path) switch lastError { case nil: found = true @@ -91,7 +166,7 @@ func ValidateLicense(logger zerolog.Logger, meta plugin.Meta, licenseFileOrDirec return errors.New("failed to validate license directory") } -func validateLicenseFile(logger zerolog.Logger, meta plugin.Meta, licenseFile string) error { +func (lc CQLicenseClient) validateLicenseFile(licenseFile string) error { licenseContents, err := os.ReadFile(licenseFile) if err != nil { return err @@ -103,14 +178,14 @@ func validateLicenseFile(logger zerolog.Logger, meta plugin.Meta, licenseFile st } if len(l.Plugins) > 0 { - ref := strings.Join([]string{meta.Team, string(meta.Kind), meta.Name}, "/") - teamRef := meta.Team + "/*" + ref := strings.Join([]string{lc.meta.Team, string(lc.meta.Kind), lc.meta.Name}, "/") + teamRef := lc.meta.Team + "/*" if !slices.Contains(l.Plugins, ref) && !slices.Contains(l.Plugins, teamRef) { return ErrLicenseNotApplicable } } - return l.IsValid(logger) + return l.IsValid(lc.logger) } func UnpackLicense(lic []byte) (*License, error) { @@ -158,3 +233,28 @@ func (l *License) IsValid(logger zerolog.Logger) error { msg.Time("expires_at", l.ExpiresAt).Msgf("Offline license for %s loaded.", l.LicensedTo) return nil } + +func (lc CQLicenseClient) validateMarketplaceLicense(ctx context.Context) error { + clientToken := uuid.New() + + resp, err := lc.awsLicenseManagerClient.CheckoutLicense(ctx, &licensemanager.CheckoutLicenseInput{ + CheckoutType: types.CheckoutTypeProvisional, + ClientToken: aws.String(clientToken.String()), + ProductSKU: aws.String(awsProductSKU), + Entitlements: []types.EntitlementData{ + { + Name: aws.String("Unlimited"), + Unit: types.EntitlementDataUnitNone, + }, + }, + // This is hardcoded for AWS Marketplace, because this is the only supported value for marketplace licenses + KeyFingerprint: aws.String("aws:294406891311:AWS/Marketplace:issuer-fingerprint"), + }) + if err != nil { + return fmt.Errorf("failed to checkout license: %w", err) + } + if len(resp.EntitlementsAllowed) == 0 { + return errors.New("no entitlements provisioned") + } + return nil +} diff --git a/premium/offline_test.go b/premium/offline_test.go index 8a7685e36b..ff251dadca 100644 --- a/premium/offline_test.go +++ b/premium/offline_test.go @@ -1,13 +1,22 @@ package premium import ( + "context" + "fmt" "os" "path/filepath" "testing" "time" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/licensemanager" + "github.com/aws/aws-sdk-go-v2/service/licensemanager/types" + "github.com/cloudquery/plugin-sdk/v4/faker" "github.com/cloudquery/plugin-sdk/v4/plugin" + "github.com/cloudquery/plugin-sdk/v4/premium/mocks" + "github.com/golang/mock/gomock" "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -153,8 +162,9 @@ func licenseTest(inputPath string, meta plugin.Meta, timeIs time.Time, expectErr timeFunc = func() time.Time { return timeIs } - - err := ValidateLicense(zerolog.Nop(), meta, inputPath) + licenseClient, err := NewLicenseClient(context.TODO(), zerolog.Nop(), WithMeta(meta), WithLicenseFileOrDirectory(inputPath)) + require.NoError(t, err) + err = licenseClient.ValidateLicense(context.TODO()) if expectError == nil { require.NoError(t, err) } else { @@ -162,3 +172,66 @@ func licenseTest(inputPath string, meta plugin.Meta, timeIs time.Time, expectErr } } } + +func TestValidateMarketplaceLicense(t *testing.T) { + ctrl := gomock.NewController(t) + m := mocks.NewMockAWSLicenseManagerInterface(ctrl) + out := licensemanager.CheckoutLicenseOutput{} + in := licenseInput{ + CheckoutLicenseInput: licensemanager.CheckoutLicenseInput{ + CheckoutType: types.CheckoutTypeProvisional, + ProductSKU: aws.String(awsProductSKU), + Entitlements: []types.EntitlementData{ + { + Name: aws.String("Unlimited"), + Unit: types.EntitlementDataUnitNone, + }, + }, + KeyFingerprint: aws.String("aws:294406891311:AWS/Marketplace:issuer-fingerprint"), + }, + } + + assert.NoError(t, faker.FakeObject(&out)) + m.EXPECT().CheckoutLicense(gomock.Any(), in).Return(&out, nil) + t.Setenv("CQ_AWS_MARKETPLACE_LICENSE", "true") + + licenseClient, err := NewLicenseClient(context.TODO(), zerolog.Nop(), WithAWSLicenseManagerClient(m)) + require.NoError(t, err) + require.NoError(t, licenseClient.ValidateLicense(context.TODO())) +} + +type licenseInput struct { + licensemanager.CheckoutLicenseInput +} + +func (li licenseInput) Matches(x any) bool { + testInput, ok := x.(*licensemanager.CheckoutLicenseInput) + if !ok { + return false + } + + if testInput.CheckoutType != li.CheckoutType { + return false + } + + for i, ent := range testInput.Entitlements { + if aws.ToString(ent.Name) != aws.ToString(li.Entitlements[i].Name) { + return false + } + if aws.ToString(ent.Value) != aws.ToString(li.Entitlements[i].Value) { + return false + } + } + + if aws.ToString(testInput.KeyFingerprint) != aws.ToString(li.KeyFingerprint) { + return false + } + if aws.ToString(testInput.ProductSKU) != aws.ToString(li.ProductSKU) { + return false + } + return true +} + +func (li licenseInput) String() string { + return fmt.Sprintf("{CheckoutType:%s Entitlements:%v KeyFingerprint:%s ProductSKU:%s}", li.CheckoutType, li.Entitlements, *li.KeyFingerprint, *li.ProductSKU) +} diff --git a/serve/plugin.go b/serve/plugin.go index 1d55ebdaaa..904b486cf5 100644 --- a/serve/plugin.go +++ b/serve/plugin.go @@ -145,15 +145,17 @@ func (s *PluginServe) newCmdPluginServe() *cobra.Command { defer shutdown() } - if licenseFile != "" { - switch err := premium.ValidateLicense(logger, s.plugin.Meta(), licenseFile); err { - case nil: - s.plugin.SetSkipUsageClient(true) - case premium.ErrLicenseNotApplicable: - // no-op: Treat as if no license was provided - default: - return fmt.Errorf("failed to validate license: %w", err) - } + licenseClient, err := premium.NewLicenseClient(cmd.Context(), logger, premium.WithMeta(s.plugin.Meta()), premium.WithLicenseFileOrDirectory(licenseFile)) + if err != nil { + return fmt.Errorf("failed to create license client: %w", err) + } + switch err := licenseClient.ValidateLicense(cmd.Context()); err { + case nil: + s.plugin.SetSkipUsageClient(true) + case premium.ErrLicenseNotApplicable: + // no-op: Treat as if no license was provided + default: + return fmt.Errorf("failed to validate license: %w", err) } var listener net.Listener