From 724c9e78ac3ec3f5104a80197fc4a6f472848ffe Mon Sep 17 00:00:00 2001 From: Nic Klaassen Date: Wed, 25 Sep 2024 11:28:12 -0700 Subject: [PATCH] update AWS KMS to aws-sdk-go-v2 (#46896) * update AWS KMS to aws-sdk-go-v2 * cd integrations/terraform && go mod tidy * cd integrations/event-handler && go mod tidy --- go.mod | 1 + integrations/event-handler/go.mod | 1 + integrations/event-handler/go.sum | 2 + integrations/terraform/go.mod | 1 + integrations/terraform/go.sum | 2 + lib/auth/auth.go | 2 +- lib/auth/keystore/aws_kms.go | 154 ++++++++++++++++------------- lib/auth/keystore/aws_kms_test.go | 142 ++++++++++++-------------- lib/auth/keystore/gcp_kms_test.go | 3 - lib/auth/keystore/keystore_test.go | 16 ++- lib/auth/keystore/manager.go | 9 +- 11 files changed, 168 insertions(+), 165 deletions(-) diff --git a/go.mod b/go.mod index d845750718e9..16be51cfef66 100644 --- a/go.mod +++ b/go.mod @@ -59,6 +59,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/eks v1.48.2 github.com/aws/aws-sdk-go-v2/service/glue v1.95.0 github.com/aws/aws-sdk-go-v2/service/iam v1.35.0 + github.com/aws/aws-sdk-go-v2/service/kms v1.35.3 github.com/aws/aws-sdk-go-v2/service/rds v1.82.2 github.com/aws/aws-sdk-go-v2/service/redshift v1.46.6 github.com/aws/aws-sdk-go-v2/service/s3 v1.61.0 diff --git a/integrations/event-handler/go.mod b/integrations/event-handler/go.mod index d402372afd52..fdd970d9bd27 100644 --- a/integrations/event-handler/go.mod +++ b/integrations/event-handler/go.mod @@ -84,6 +84,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 // indirect + github.com/aws/aws-sdk-go-v2/service/kms v1.35.3 // indirect github.com/aws/aws-sdk-go-v2/service/rds v1.82.2 // indirect github.com/aws/aws-sdk-go-v2/service/s3 v1.61.0 // indirect github.com/aws/aws-sdk-go-v2/service/ssm v1.52.6 // indirect diff --git a/integrations/event-handler/go.sum b/integrations/event-handler/go.sum index 0239f47988fd..81d5f2444de4 100644 --- a/integrations/event-handler/go.sum +++ b/integrations/event-handler/go.sum @@ -764,6 +764,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 h1:tJ5RnkHC github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18/go.mod h1:++NHzT+nAF7ZPrHPsA+ENvsXkOO8wEu+C6RXltAG4/c= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 h1:jg16PhLPUiHIj8zYIW6bqzeQSuHVEiWnGA0Brz5Xv2I= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16/go.mod h1:Uyk1zE1VVdsHSU7096h/rwnXDzOzYQVl+FNPhPw7ShY= +github.com/aws/aws-sdk-go-v2/service/kms v1.35.3 h1:UPTdlTOwWUX49fVi7cymEN6hDqCwe3LNv1vi7TXUutk= +github.com/aws/aws-sdk-go-v2/service/kms v1.35.3/go.mod h1:gjDP16zn+WWalyaUqwCCioQ8gU8lzttCCc9jYsiQI/8= github.com/aws/aws-sdk-go-v2/service/rds v1.82.2 h1:kO/fQcueYZvuL5kPzTPQ503cKZj8jyBNg1MlnIqpFPg= github.com/aws/aws-sdk-go-v2/service/rds v1.82.2/go.mod h1:hfUZhydujCniydsJdzZ9bwzX6nUvbfnhhYQeFNREC2I= github.com/aws/aws-sdk-go-v2/service/s3 v1.61.0 h1:Wb544Wh+xfSXqJ/j3R4aX9wrKUoZsJNmilBYZb3mKQ4= diff --git a/integrations/terraform/go.mod b/integrations/terraform/go.mod index bc6b5fb8f4df..075453bb58e1 100644 --- a/integrations/terraform/go.mod +++ b/integrations/terraform/go.mod @@ -97,6 +97,7 @@ require ( github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.3.18 // indirect github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 // indirect github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 // indirect + github.com/aws/aws-sdk-go-v2/service/kms v1.35.3 // indirect github.com/aws/aws-sdk-go-v2/service/rds v1.82.2 // indirect github.com/aws/aws-sdk-go-v2/service/s3 v1.61.0 // indirect github.com/aws/aws-sdk-go-v2/service/ssm v1.52.6 // indirect diff --git a/integrations/terraform/go.sum b/integrations/terraform/go.sum index 03c8fc34ff45..f17064b2b25a 100644 --- a/integrations/terraform/go.sum +++ b/integrations/terraform/go.sum @@ -827,6 +827,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18 h1:tJ5RnkHC github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.11.18/go.mod h1:++NHzT+nAF7ZPrHPsA+ENvsXkOO8wEu+C6RXltAG4/c= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16 h1:jg16PhLPUiHIj8zYIW6bqzeQSuHVEiWnGA0Brz5Xv2I= github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.17.16/go.mod h1:Uyk1zE1VVdsHSU7096h/rwnXDzOzYQVl+FNPhPw7ShY= +github.com/aws/aws-sdk-go-v2/service/kms v1.35.3 h1:UPTdlTOwWUX49fVi7cymEN6hDqCwe3LNv1vi7TXUutk= +github.com/aws/aws-sdk-go-v2/service/kms v1.35.3/go.mod h1:gjDP16zn+WWalyaUqwCCioQ8gU8lzttCCc9jYsiQI/8= github.com/aws/aws-sdk-go-v2/service/rds v1.82.2 h1:kO/fQcueYZvuL5kPzTPQ503cKZj8jyBNg1MlnIqpFPg= github.com/aws/aws-sdk-go-v2/service/rds v1.82.2/go.mod h1:hfUZhydujCniydsJdzZ9bwzX6nUvbfnhhYQeFNREC2I= github.com/aws/aws-sdk-go-v2/service/s3 v1.61.0 h1:Wb544Wh+xfSXqJ/j3R4aX9wrKUoZsJNmilBYZb3mKQ4= diff --git a/lib/auth/auth.go b/lib/auth/auth.go index fa5ad16098f5..958ac86c462d 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -382,8 +382,8 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { keystoreOpts := &keystore.Options{ HostUUID: cfg.HostUUID, ClusterName: cfg.ClusterName, - CloudClients: cfg.CloudClients, AuthPreferenceGetter: cfg.ClusterConfiguration, + FIPS: cfg.FIPS, } if cfg.KeyStoreConfig.PKCS11 != (servicecfg.PKCS11Config{}) { if !modules.GetModules().Features().GetEntitlement(entitlements.HSM).Enabled { diff --git a/lib/auth/keystore/aws_kms.go b/lib/auth/keystore/aws_kms.go index fd599336ec0a..3024dad36383 100644 --- a/lib/auth/keystore/aws_kms.go +++ b/lib/auth/keystore/aws_kms.go @@ -31,20 +31,18 @@ import ( "sync" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/aws/aws-sdk-go/service/kms/kmsiface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/config" + "github.com/aws/aws-sdk-go-v2/service/kms" + kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "golang.org/x/sync/errgroup" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/retryutils" - "github.com/gravitational/teleport/lib/cloud" - awslib "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/service/servicecfg" ) @@ -58,15 +56,8 @@ const ( pendingKeyTimeout = 30 * time.Second ) -type CloudClientProvider interface { - // GetAWSSTSClient returns AWS STS client for the specified region. - GetAWSSTSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (stsiface.STSAPI, error) - // GetAWSKMSClient returns AWS KMS client for the specified region. - GetAWSKMSClient(ctx context.Context, region string, opts ...cloud.AWSOptionsFn) (kmsiface.KMSAPI, error) -} - type awsKMSKeystore struct { - kms kmsiface.KMSAPI + kms kmsClient clusterName types.ClusterName awsAccount string awsRegion string @@ -75,21 +66,33 @@ type awsKMSKeystore struct { } func newAWSKMSKeystore(ctx context.Context, cfg *servicecfg.AWSKMSConfig, opts *Options) (*awsKMSKeystore, error) { - stsClient, err := opts.CloudClients.GetAWSSTSClient(ctx, cfg.AWSRegion, cloud.WithAmbientCredentials()) - if err != nil { - return nil, trace.Wrap(err) + stsClient, kmsClient := opts.awsSTSClient, opts.awsKMSClient + if stsClient == nil || kmsClient == nil { + useFIPSEndpoint := aws.FIPSEndpointStateUnset + if opts.FIPS { + useFIPSEndpoint = aws.FIPSEndpointStateEnabled + } + awsCfg, err := config.LoadDefaultConfig(ctx, + config.WithRegion(cfg.AWSRegion), + config.WithUseFIPSEndpoint(useFIPSEndpoint), + ) + if err != nil { + return nil, trace.Wrap(err, "loading default AWS config") + } + if stsClient == nil { + stsClient = sts.NewFromConfig(awsCfg) + } + if kmsClient == nil { + kmsClient = kms.NewFromConfig(awsCfg) + } } - id, err := stsClient.GetCallerIdentity(&sts.GetCallerIdentityInput{}) + id, err := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) if err != nil { - return nil, trace.Wrap(err) + return nil, trace.Wrap(err, "checking AWS account of local credentials for AWS KMS") } - if aws.StringValue(id.Account) != cfg.AWSAccount { + if aws.ToString(id.Account) != cfg.AWSAccount { return nil, trace.BadParameter("configured AWS KMS account %q does not match AWS account of ambient credentials %q", - cfg.AWSAccount, aws.StringValue(id.Account)) - } - kmsClient, err := opts.CloudClients.GetAWSKMSClient(ctx, cfg.AWSRegion, cloud.WithAmbientCredentials()) - if err != nil { - return nil, trace.Wrap(err) + cfg.AWSAccount, aws.ToString(id.Account)) } clock := opts.clockworkOverride if clock == nil { @@ -125,11 +128,11 @@ func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites a.logger.InfoContext(ctx, "Creating new AWS KMS keypair.", "algorithm", alg) - output, err := a.kms.CreateKey(&kms.CreateKeyInput{ + output, err := a.kms.CreateKey(ctx, &kms.CreateKeyInput{ Description: aws.String("Teleport CA key"), - KeySpec: &alg, - KeyUsage: aws.String(kms.KeyUsageTypeSignVerify), - Tags: []*kms.Tag{ + KeySpec: alg, + KeyUsage: kmstypes.KeyUsageTypeSignVerify, + Tags: []kmstypes.Tag{ { TagKey: aws.String(clusterTagKey), TagValue: aws.String(a.clusterName.GetClusterName()), @@ -142,7 +145,7 @@ func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites if output.KeyMetadata == nil { return nil, nil, trace.Errorf("KeyMetadata of generated key is nil") } - keyARN := aws.StringValue(output.KeyMetadata.Arn) + keyARN := aws.ToString(output.KeyMetadata.Arn) signer, err := a.newSigner(ctx, keyARN) if err != nil { return nil, nil, trace.Wrap(err) @@ -155,14 +158,14 @@ func (a *awsKMSKeystore) generateKey(ctx context.Context, algorithm cryptosuites return keyID, signer, nil } -func awsAlgorithm(alg cryptosuites.Algorithm) (string, error) { +func awsAlgorithm(alg cryptosuites.Algorithm) (kmstypes.KeySpec, error) { switch alg { case cryptosuites.RSA2048: - return kms.KeySpecRsa2048, nil + return kmstypes.KeySpecRsa2048, nil case cryptosuites.ECDSAP256: - return kms.KeySpecEccNistP256, nil + return kmstypes.KeySpecEccNistP256, nil } - return "", trace.BadParameter("unsupported algorithm: %v", alg) + return "", trace.BadParameter("unsupported algorithm for AWS KMS: %v", alg) } // getSigner returns a crypto.Signer for the given key identifier, if it is found. @@ -177,7 +180,7 @@ func (a *awsKMSKeystore) getSigner(ctx context.Context, rawKey []byte, publicKey type awsKMSSigner struct { keyARN string pub crypto.PublicKey - kms kmsiface.KMSAPI + kms kmsClient } func (a *awsKMSKeystore) newSigner(ctx context.Context, keyARN string) (*awsKMSSigner, error) { @@ -211,7 +214,7 @@ func (a *awsKMSKeystore) getPublicKeyDER(ctx context.Context, keyARN string) ([] timeout := a.clock.NewTimer(pendingKeyTimeout) defer timeout.Stop() for { - output, err := a.kms.GetPublicKeyWithContext(ctx, &kms.GetPublicKeyInput{ + output, err := a.kms.GetPublicKey(ctx, &kms.GetPublicKeyInput{ KeyId: aws.String(keyARN), }) if err == nil { @@ -222,8 +225,8 @@ func (a *awsKMSKeystore) getPublicKeyDER(ctx context.Context, keyARN string) ([] // error types // https://docs.aws.amazon.com/kms/latest/developerguide/programming-eventual-consistency.html var ( - notFound *kms.NotFoundException - invalidState *kms.InvalidStateException + notFound *kmstypes.NotFoundException + invalidState *kmstypes.KMSInvalidStateException ) if !errors.As(err, ¬Found) && !errors.As(err, &invalidState) { return nil, trace.Wrap(err, "unexpected error fetching AWS KMS public key") @@ -257,34 +260,34 @@ func (a *awsKMSSigner) Public() crypto.PublicKey { // Sign signs the message digest. func (a *awsKMSSigner) Sign(rand io.Reader, digest []byte, opts crypto.SignerOpts) (signature []byte, err error) { - var signingAlg string + var signingAlg kmstypes.SigningAlgorithmSpec switch opts.HashFunc() { case crypto.SHA256: switch a.pub.(type) { case *rsa.PublicKey: - signingAlg = kms.SigningAlgorithmSpecRsassaPkcs1V15Sha256 + signingAlg = kmstypes.SigningAlgorithmSpecRsassaPkcs1V15Sha256 case *ecdsa.PublicKey: - signingAlg = kms.SigningAlgorithmSpecEcdsaSha256 + signingAlg = kmstypes.SigningAlgorithmSpecEcdsaSha256 default: return nil, trace.BadParameter("unsupported hash func %q for AWS KMS key type %T", opts.HashFunc(), a.pub) } case crypto.SHA512: switch a.pub.(type) { case *rsa.PublicKey: - signingAlg = kms.SigningAlgorithmSpecRsassaPkcs1V15Sha512 + signingAlg = kmstypes.SigningAlgorithmSpecRsassaPkcs1V15Sha512 case *ecdsa.PublicKey: - signingAlg = kms.SigningAlgorithmSpecEcdsaSha512 + signingAlg = kmstypes.SigningAlgorithmSpecEcdsaSha512 default: return nil, trace.BadParameter("unsupported hash func %q for AWS KMS key type %T", opts.HashFunc(), a.pub) } default: return nil, trace.BadParameter("unsupported hash func %q for AWS KMS key", opts.HashFunc()) } - output, err := a.kms.Sign(&kms.SignInput{ + output, err := a.kms.Sign(context.TODO(), &kms.SignInput{ KeyId: aws.String(a.keyARN), Message: digest, - MessageType: aws.String(kms.MessageTypeDigest), - SigningAlgorithm: aws.String(signingAlg), + MessageType: kmstypes.MessageTypeDigest, + SigningAlgorithm: signingAlg, }) if err != nil { return nil, trace.Wrap(err) @@ -298,9 +301,9 @@ func (a *awsKMSKeystore) deleteKey(ctx context.Context, rawKey []byte) error { if err != nil { return trace.Wrap(err) } - _, err = a.kms.ScheduleKeyDeletion(&kms.ScheduleKeyDeletionInput{ + _, err = a.kms.ScheduleKeyDeletion(ctx, &kms.ScheduleKeyDeletionInput{ KeyId: aws.String(keyID.arn), - PendingWindowInDays: aws.Int64(7), + PendingWindowInDays: aws.Int32(7), }) return trace.Wrap(err, "error deleting AWS KMS key") } @@ -370,29 +373,26 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by } // Check if this key was created by this Teleport cluster. - output, err := a.kms.ListResourceTagsWithContext(ctx, &kms.ListResourceTagsInput{ + output, err := a.kms.ListResourceTags(ctx, &kms.ListResourceTagsInput{ KeyId: aws.String(keyARN), }) if err != nil { - err = awslib.ConvertRequestFailureError(err) - if trace.IsAccessDenied(err) { - // It's entirely expected that we'll not be allowed to fetch - // tags for some keys, don't worry about deleting those. - return nil - } - return trace.Wrap(err, "failed to fetch tags for AWS KMS key %q", keyARN) + // It's entirely expected that we won't be allowed to fetch + // tags for some keys, don't worry about deleting those. + a.logger.DebugContext(ctx, "failed to fetch tags for AWS KMS key, skipping", "key_arn", keyARN, "error", err) + return nil } clusterName := a.clusterName.GetClusterName() - if !slices.ContainsFunc(output.Tags, func(tag *kms.Tag) bool { - return aws.StringValue(tag.TagKey) == clusterTagKey && aws.StringValue(tag.TagValue) == clusterName + if !slices.ContainsFunc(output.Tags, func(tag kmstypes.Tag) bool { + return aws.ToString(tag.TagKey) == clusterTagKey && aws.ToString(tag.TagValue) == clusterName }) { // This key was not created by this Teleport cluster, never delete it. return nil } // Check if this key is not enabled or was created in the past 5 minutes. - describeOutput, err := a.kms.DescribeKeyWithContext(ctx, &kms.DescribeKeyInput{ + describeOutput, err := a.kms.DescribeKey(ctx, &kms.DescribeKeyInput{ KeyId: aws.String(keyARN), }) if err != nil { @@ -401,12 +401,12 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by if describeOutput.KeyMetadata == nil { return trace.Errorf("failed to describe AWS KMS key %q", keyARN) } - if keyState := aws.StringValue(describeOutput.KeyMetadata.KeyState); keyState != "Enabled" { + if keyState := describeOutput.KeyMetadata.KeyState; keyState != kmstypes.KeyStateEnabled { a.logger.InfoContext(ctx, "deleteUnusedKeys skipping AWS KMS key which is not in enabled state.", "key_arn", keyARN, "key_state", keyState) return nil } - creationDate := aws.TimeValue(describeOutput.KeyMetadata.CreationDate) + creationDate := aws.ToTime(describeOutput.KeyMetadata.CreationDate) if a.clock.Now().Sub(creationDate).Abs() < 5*time.Minute { // Never delete keys created in the last 5 minutes in case they were // created by a different auth server and just haven't been added to @@ -438,9 +438,9 @@ func (a *awsKMSKeystore) deleteUnusedKeys(ctx context.Context, activeKeys [][]by for _, keyARN := range keysToDelete { a.logger.InfoContext(ctx, "Deleting unused AWS KMS key.", "key_arn", keyARN) - if _, err := a.kms.ScheduleKeyDeletion(&kms.ScheduleKeyDeletionInput{ + if _, err := a.kms.ScheduleKeyDeletion(ctx, &kms.ScheduleKeyDeletionInput{ KeyId: aws.String(keyARN), - PendingWindowInDays: aws.Int64(7), + PendingWindowInDays: aws.Int32(7), }); err != nil { return trace.Wrap(err, "failed to schedule AWS KMS key %q for deletion", keyARN) } @@ -459,17 +459,17 @@ func (a *awsKMSKeystore) forEachKey(ctx context.Context, fn func(ctx context.Con if marker != "" { markerInput = aws.String(marker) } - output, err := a.kms.ListKeysWithContext(ctx, &kms.ListKeysInput{ + output, err := a.kms.ListKeys(ctx, &kms.ListKeysInput{ Marker: markerInput, - Limit: aws.Int64(1000), + Limit: aws.Int32(1000), }) if err != nil { return trace.Wrap(err, "failed to list AWS KMS keys") } - marker = aws.StringValue(output.NextMarker) - more = aws.BoolValue(output.Truncated) + marker = aws.ToString(output.NextMarker) + more = output.Truncated for _, keyEntry := range output.Keys { - keyArn := aws.StringValue(keyEntry.KeyArn) + keyArn := aws.ToString(keyEntry.KeyArn) errGroup.Go(func() error { return trace.Wrap(fn(ctx, keyArn)) }) @@ -501,3 +501,17 @@ func parseAWSKMSKeyID(raw []byte) (awsKMSKeyID, error) { region: parsedARN.Region, }, nil } + +type kmsClient interface { + CreateKey(context.Context, *kms.CreateKeyInput, ...func(*kms.Options)) (*kms.CreateKeyOutput, error) + GetPublicKey(context.Context, *kms.GetPublicKeyInput, ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) + ListKeys(context.Context, *kms.ListKeysInput, ...func(*kms.Options)) (*kms.ListKeysOutput, error) + ScheduleKeyDeletion(context.Context, *kms.ScheduleKeyDeletionInput, ...func(*kms.Options)) (*kms.ScheduleKeyDeletionOutput, error) + DescribeKey(context.Context, *kms.DescribeKeyInput, ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) + ListResourceTags(context.Context, *kms.ListResourceTagsInput, ...func(*kms.Options)) (*kms.ListResourceTagsOutput, error) + Sign(context.Context, *kms.SignInput, ...func(*kms.Options)) (*kms.SignOutput, error) +} + +type stsClient interface { + GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) +} diff --git a/lib/auth/keystore/aws_kms_test.go b/lib/auth/keystore/aws_kms_test.go index 5b77367b7c2c..b09a3f5a11c8 100644 --- a/lib/auth/keystore/aws_kms_test.go +++ b/lib/auth/keystore/aws_kms_test.go @@ -27,13 +27,11 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/aws/request" - "github.com/aws/aws-sdk-go/service/kms" - "github.com/aws/aws-sdk-go/service/kms/kmsiface" - "github.com/aws/aws-sdk-go/service/sts" - "github.com/aws/aws-sdk-go/service/sts/stsiface" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/aws/arn" + "github.com/aws/aws-sdk-go-v2/service/kms" + kmstypes "github.com/aws/aws-sdk-go-v2/service/kms/types" + "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -42,7 +40,6 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" @@ -72,11 +69,9 @@ func TestAWSKMS_DeleteUnusedKeys(t *testing.T) { ClusterName: clusterName, HostUUID: "uuid", AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1}, - CloudClients: &cloud.TestCloudClients{ - KMS: fakeKMS, - STS: &fakeAWSSTSClient{ - account: "123456789012", - }, + awsKMSClient: fakeKMS, + awsSTSClient: &fakeAWSSTSClient{ + account: "123456789012", }, clockworkOverride: clock, } @@ -93,7 +88,7 @@ func TestAWSKMS_DeleteUnusedKeys(t *testing.T) { err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/) require.NoError(t, err) for _, key := range fakeKMS.keys { - assert.Equal(t, "Enabled", key.state) + assert.Equal(t, kmstypes.KeyStateEnabled, key.state) } // Keys created more than 5 minutes ago should be deleted. @@ -101,31 +96,31 @@ func TestAWSKMS_DeleteUnusedKeys(t *testing.T) { err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/) require.NoError(t, err) for _, key := range fakeKMS.keys { - assert.Equal(t, "PendingDeletion", key.state) + assert.Equal(t, kmstypes.KeyStatePendingDeletion, key.state) } // Insert a key created by a different Teleport cluster, it should not be // deleted by the keystore. - output, err := fakeKMS.CreateKey(&kms.CreateKeyInput{ - KeySpec: aws.String(kms.KeySpecEccNistP256), - Tags: []*kms.Tag{ - &kms.Tag{ + output, err := fakeKMS.CreateKey(ctx, &kms.CreateKeyInput{ + KeySpec: kmstypes.KeySpecEccNistP256, + Tags: []kmstypes.Tag{ + kmstypes.Tag{ TagKey: aws.String(clusterTagKey), TagValue: aws.String("other-cluster"), }, }, }) require.NoError(t, err) - otherClusterKeyARN := aws.StringValue(output.KeyMetadata.Arn) + otherClusterKeyARN := aws.ToString(output.KeyMetadata.Arn) clock.Advance(6 * time.Minute) err = keyStore.DeleteUnusedKeys(ctx, nil /*activeKeys*/) require.NoError(t, err) for _, key := range fakeKMS.keys { if key.arn == otherClusterKeyARN { - assert.Equal(t, "Enabled", key.state) + assert.Equal(t, kmstypes.KeyStateEnabled, key.state) } else { - assert.Equal(t, "PendingDeletion", key.state) + assert.Equal(t, kmstypes.KeyStatePendingDeletion, key.state) } } } @@ -144,11 +139,9 @@ func TestAWSKMS_WrongAccount(t *testing.T) { ClusterName: clusterName, HostUUID: "uuid", AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1}, - CloudClients: &cloud.TestCloudClients{ - KMS: newFakeAWSKMSService(t, clock, "222222222222", "us-west-2", 1000), - STS: &fakeAWSSTSClient{ - account: "222222222222", - }, + awsKMSClient: newFakeAWSKMSService(t, clock, "222222222222", "us-west-2", 1000), + awsSTSClient: &fakeAWSSTSClient{ + account: "222222222222", }, } _, err = NewManager(context.Background(), cfg, opts) @@ -177,11 +170,9 @@ func TestAWSKMS_RetryWhilePending(t *testing.T) { ClusterName: clusterName, HostUUID: "uuid", AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1}, - CloudClients: &cloud.TestCloudClients{ - KMS: kms, - STS: &fakeAWSSTSClient{ - account: "111111111111", - }, + awsKMSClient: kms, + awsSTSClient: &fakeAWSSTSClient{ + account: "111111111111", }, clockworkOverride: clock, } @@ -221,8 +212,6 @@ func TestAWSKMS_RetryWhilePending(t *testing.T) { } type fakeAWSKMSService struct { - kmsiface.KMSAPI - keys []*fakeAWSKMSKey clock clockwork.Clock account string @@ -243,12 +232,12 @@ func newFakeAWSKMSService(t *testing.T, clock clockwork.Clock, account string, r type fakeAWSKMSKey struct { arn string privKeyPEM []byte - tags []*kms.Tag + tags []kmstypes.Tag creationDate time.Time - state string + state kmstypes.KeyState } -func (f *fakeAWSKMSService) CreateKey(input *kms.CreateKeyInput) (*kms.CreateKeyOutput, error) { +func (f *fakeAWSKMSService) CreateKey(_ context.Context, input *kms.CreateKeyInput, _ ...func(*kms.Options)) (*kms.CreateKeyOutput, error) { id := uuid.NewString() a := arn.ARN{ Partition: "aws", @@ -257,15 +246,15 @@ func (f *fakeAWSKMSService) CreateKey(input *kms.CreateKeyInput) (*kms.CreateKey AccountID: f.account, Resource: id, } - state := "Enabled" + state := kmstypes.KeyStateEnabled if f.keyPendingDuration > 0 { - state = "Pending" + state = kmstypes.KeyStateCreating } var privKeyPEM []byte - switch aws.StringValue(input.KeySpec) { - case kms.KeySpecRsa2048: + switch input.KeySpec { + case kmstypes.KeySpecRsa2048: privKeyPEM = testRSA2048PrivateKeyPEM - case kms.KeySpecEccNistP256: + case kmstypes.KeySpecEccNistP256: signer, err := cryptosuites.GenerateKeyWithAlgorithm(cryptosuites.ECDSAP256) if err != nil { return nil, trace.Wrap(err) @@ -285,20 +274,20 @@ func (f *fakeAWSKMSService) CreateKey(input *kms.CreateKeyInput) (*kms.CreateKey state: state, }) return &kms.CreateKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ + KeyMetadata: &kmstypes.KeyMetadata{ Arn: aws.String(a.String()), KeyId: aws.String(id), }, }, nil } -func (f *fakeAWSKMSService) GetPublicKeyWithContext(ctx context.Context, input *kms.GetPublicKeyInput, opts ...request.Option) (*kms.GetPublicKeyOutput, error) { - key, err := f.findKey(aws.StringValue(input.KeyId)) +func (f *fakeAWSKMSService) GetPublicKey(_ context.Context, input *kms.GetPublicKeyInput, _ ...func(*kms.Options)) (*kms.GetPublicKeyOutput, error) { + key, err := f.findKey(aws.ToString(input.KeyId)) if err != nil { return nil, trace.Wrap(err) } - if key.state != "Enabled" { - return nil, trace.NotFound("key %q is not enabled", aws.StringValue(input.KeyId)) + if key.state != kmstypes.KeyStateEnabled { + return nil, trace.NotFound("key %q is not enabled", aws.ToString(input.KeyId)) } privateKey, err := keys.ParsePrivateKey(key.privKeyPEM) if err != nil { @@ -313,26 +302,26 @@ func (f *fakeAWSKMSService) GetPublicKeyWithContext(ctx context.Context, input * }, nil } -func (f *fakeAWSKMSService) Sign(input *kms.SignInput) (*kms.SignOutput, error) { - key, err := f.findKey(aws.StringValue(input.KeyId)) +func (f *fakeAWSKMSService) Sign(_ context.Context, input *kms.SignInput, _ ...func(*kms.Options)) (*kms.SignOutput, error) { + key, err := f.findKey(aws.ToString(input.KeyId)) if err != nil { return nil, trace.Wrap(err) } - if key.state != "Enabled" { - return nil, trace.NotFound("key %q is not enabled", aws.StringValue(input.KeyId)) + if key.state != kmstypes.KeyStateEnabled { + return nil, trace.NotFound("key %q is not enabled", aws.ToString(input.KeyId)) } signer, err := keys.ParsePrivateKey(key.privKeyPEM) if err != nil { return nil, trace.Wrap(err) } var opts crypto.SignerOpts - switch aws.StringValue(input.SigningAlgorithm) { - case kms.SigningAlgorithmSpecRsassaPkcs1V15Sha256, kms.SigningAlgorithmSpecEcdsaSha256: + switch input.SigningAlgorithm { + case kmstypes.SigningAlgorithmSpecRsassaPkcs1V15Sha256, kmstypes.SigningAlgorithmSpecEcdsaSha256: opts = crypto.SHA256 - case kms.SigningAlgorithmSpecRsassaPkcs1V15Sha512: + case kmstypes.SigningAlgorithmSpecRsassaPkcs1V15Sha512: opts = crypto.SHA512 default: - return nil, trace.BadParameter("unsupported SigningAlgorithm %q", aws.StringValue(input.SigningAlgorithm)) + return nil, trace.BadParameter("unsupported SigningAlgorithm %q", input.SigningAlgorithm) } signature, err := signer.Sign(rand.Reader, input.Message, opts) if err != nil { @@ -343,40 +332,40 @@ func (f *fakeAWSKMSService) Sign(input *kms.SignInput) (*kms.SignOutput, error) }, nil } -func (f *fakeAWSKMSService) ScheduleKeyDeletion(input *kms.ScheduleKeyDeletionInput) (*kms.ScheduleKeyDeletionOutput, error) { - key, err := f.findKey(aws.StringValue(input.KeyId)) +func (f *fakeAWSKMSService) ScheduleKeyDeletion(_ context.Context, input *kms.ScheduleKeyDeletionInput, _ ...func(*kms.Options)) (*kms.ScheduleKeyDeletionOutput, error) { + key, err := f.findKey(aws.ToString(input.KeyId)) if err != nil { return nil, trace.Wrap(err) } - key.state = "PendingDeletion" + key.state = kmstypes.KeyStatePendingDeletion return &kms.ScheduleKeyDeletionOutput{}, nil } -func (f *fakeAWSKMSService) ListKeysWithContext(ctx aws.Context, input *kms.ListKeysInput, opts ...request.Option) (*kms.ListKeysOutput, error) { - pageLimit := min(int(aws.Int64Value(input.Limit)), f.pageLimit) +func (f *fakeAWSKMSService) ListKeys(_ context.Context, input *kms.ListKeysInput, _ ...func(*kms.Options)) (*kms.ListKeysOutput, error) { + pageLimit := min(int(aws.ToInt32(input.Limit)), f.pageLimit) output := &kms.ListKeysOutput{} i := 0 if input.Marker != nil { var err error - i, err = strconv.Atoi(aws.StringValue(input.Marker)) + i, err = strconv.Atoi(aws.ToString(input.Marker)) if err != nil { return nil, trace.Wrap(err) } } for ; i < len(f.keys) && len(output.Keys) < pageLimit; i++ { - output.Keys = append(output.Keys, &kms.KeyListEntry{ + output.Keys = append(output.Keys, kmstypes.KeyListEntry{ KeyArn: aws.String(f.keys[i].arn), }) } if i < len(f.keys) { output.NextMarker = aws.String(strconv.Itoa(i)) - output.Truncated = aws.Bool(true) + output.Truncated = true } return output, nil } -func (f *fakeAWSKMSService) ListResourceTagsWithContext(ctx aws.Context, input *kms.ListResourceTagsInput, opts ...request.Option) (*kms.ListResourceTagsOutput, error) { - key, err := f.findKey(aws.StringValue(input.KeyId)) +func (f *fakeAWSKMSService) ListResourceTags(_ context.Context, input *kms.ListResourceTagsInput, _ ...func(*kms.Options)) (*kms.ListResourceTagsOutput, error) { + key, err := f.findKey(aws.ToString(input.KeyId)) if err != nil { return nil, trace.Wrap(err) } @@ -385,15 +374,15 @@ func (f *fakeAWSKMSService) ListResourceTagsWithContext(ctx aws.Context, input * }, nil } -func (f *fakeAWSKMSService) DescribeKeyWithContext(ctx aws.Context, input *kms.DescribeKeyInput, opts ...request.Option) (*kms.DescribeKeyOutput, error) { - key, err := f.findKey(aws.StringValue(input.KeyId)) +func (f *fakeAWSKMSService) DescribeKey(_ context.Context, input *kms.DescribeKeyInput, _ ...func(*kms.Options)) (*kms.DescribeKeyOutput, error) { + key, err := f.findKey(aws.ToString(input.KeyId)) if err != nil { return nil, trace.Wrap(err) } return &kms.DescribeKeyOutput{ - KeyMetadata: &kms.KeyMetadata{ + KeyMetadata: &kmstypes.KeyMetadata{ CreationDate: aws.Time(key.creationDate), - KeyState: aws.String(key.state), + KeyState: key.state, }, }, nil } @@ -401,29 +390,28 @@ func (f *fakeAWSKMSService) DescribeKeyWithContext(ctx aws.Context, input *kms.D func (f *fakeAWSKMSService) findKey(arn string) (*fakeAWSKMSKey, error) { i := slices.IndexFunc(f.keys, func(k *fakeAWSKMSKey) bool { return k.arn == arn }) if i < 0 { - return nil, &kms.NotFoundException{ - Message_: aws.String(fmt.Sprintf("key %q not found", arn)), + return nil, &kmstypes.NotFoundException{ + Message: aws.String(fmt.Sprintf("key %q not found", arn)), } } key := f.keys[i] - if key.state != "Pending" { + if key.state != kmstypes.KeyStateCreating { return key, nil } if f.clock.Now().Before(key.creationDate.Add(f.keyPendingDuration)) { - return nil, &kms.NotFoundException{ - Message_: aws.String(fmt.Sprintf("key %q not found", arn)), + return nil, &kmstypes.NotFoundException{ + Message: aws.String(fmt.Sprintf("key %q not found", arn)), } } - key.state = "Enabled" + key.state = kmstypes.KeyStateEnabled return key, nil } type fakeAWSSTSClient struct { - stsiface.STSAPI account, arn, userID string } -func (f *fakeAWSSTSClient) GetCallerIdentity(*sts.GetCallerIdentityInput) (*sts.GetCallerIdentityOutput, error) { +func (f *fakeAWSSTSClient) GetCallerIdentity(_ context.Context, _ *sts.GetCallerIdentityInput, _ ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { return &sts.GetCallerIdentityOutput{ Account: aws.String(f.account), Arn: aws.String(f.arn), diff --git a/lib/auth/keystore/gcp_kms_test.go b/lib/auth/keystore/gcp_kms_test.go index fc6a44c5a45d..e43825b8c752 100644 --- a/lib/auth/keystore/gcp_kms_test.go +++ b/lib/auth/keystore/gcp_kms_test.go @@ -48,7 +48,6 @@ import ( "github.com/gravitational/teleport/api/utils/grpc/interceptors" "github.com/gravitational/teleport/api/utils/keys" "github.com/gravitational/teleport/lib/auth/keystore/internal/faketime" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/jwt" "github.com/gravitational/teleport/lib/service/servicecfg" @@ -406,7 +405,6 @@ func TestGCPKMSKeystore(t *testing.T) { ClusterName: clusterName, HostUUID: "test-host-id", AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1}, - CloudClients: &cloud.TestCloudClients{}, kmsClient: kmsClient, faketimeOverride: clock, }) @@ -684,7 +682,6 @@ func TestGCPKMSDeleteUnusedKeys(t *testing.T) { ClusterName: clusterName, HostUUID: localHostID, AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1}, - CloudClients: &cloud.TestCloudClients{}, kmsClient: kmsClient, }) require.NoError(t, err, "error while creating test keystore manager") diff --git a/lib/auth/keystore/keystore_test.go b/lib/auth/keystore/keystore_test.go index 2298c01d79d5..0b9e88067776 100644 --- a/lib/auth/keystore/keystore_test.go +++ b/lib/auth/keystore/keystore_test.go @@ -30,7 +30,7 @@ import ( "testing" "time" - "github.com/aws/aws-sdk-go/aws/arn" + "github.com/aws/aws-sdk-go-v2/aws/arn" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -40,7 +40,6 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils/keys" - "github.com/gravitational/teleport/lib/cloud" "github.com/gravitational/teleport/lib/cryptosuites" "github.com/gravitational/teleport/lib/service/servicecfg" "github.com/gravitational/teleport/lib/services" @@ -548,11 +547,9 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack { HostUUID: hostUUID, Logger: logger, AuthPreferenceGetter: &fakeAuthPreferenceGetter{types.SignatureAlgorithmSuite_SIGNATURE_ALGORITHM_SUITE_HSM_V1}, - CloudClients: &cloud.TestCloudClients{ - KMS: newFakeAWSKMSService(t, clock, "123456789012", "us-west-2", 100), - STS: &fakeAWSSTSClient{ - account: "123456789012", - }, + awsKMSClient: newFakeAWSKMSService(t, clock, "123456789012", "us-west-2", 100), + awsSTSClient: &fakeAWSSTSClient{ + account: "123456789012", }, kmsClient: testGCPKMSClient, clockworkOverride: clock, @@ -645,8 +642,9 @@ func newTestPack(ctx context.Context, t *testing.T) *testPack { if config, ok := awsKMSTestConfig(t); ok { opts := baseOpts - opts.CloudClients, err = cloud.NewClients() - require.NoError(t, err) + // Unset the fake clients so this test can use the real AWS clients. + opts.awsKMSClient = nil + opts.awsSTSClient = nil backend, err := newAWSKMSKeystore(ctx, &config.AWSKMS, &opts) require.NoError(t, err) diff --git a/lib/auth/keystore/manager.go b/lib/auth/keystore/manager.go index 3fb4ebac1130..259efc7ff85b 100644 --- a/lib/auth/keystore/manager.go +++ b/lib/auth/keystore/manager.go @@ -153,9 +153,11 @@ type Options struct { Logger *slog.Logger // AuthPreferenceGetter provides the current cluster auth preference. AuthPreferenceGetter cryptosuites.AuthPreferenceGetter - // CloudClients provides cloud clients. - CloudClients CloudClientProvider + // FIPS means FedRAMP/FIPS 140-2 compliant configuration was requested. + FIPS bool + awsKMSClient kmsClient + awsSTSClient stsClient kmsClient *kms.KeyManagementClient clockworkOverride clockwork.Clock // GCPKMS uses a special fake clock that seemed more testable at the time. @@ -167,9 +169,6 @@ func (opts *Options) CheckAndSetDefaults() error { if opts.ClusterName == nil { return trace.BadParameter("ClusterName is required") } - if opts.CloudClients == nil { - return trace.BadParameter("CloudClients is required") - } if opts.AuthPreferenceGetter == nil { return trace.BadParameter("AuthPreferenceGetter is required") }