diff --git a/aws/s3_storage_client.go b/aws/s3_storage_client.go index db4bafcf3841..2c7ec0fa71fa 100644 --- a/aws/s3_storage_client.go +++ b/aws/s3_storage_client.go @@ -72,6 +72,7 @@ type S3Config struct { SSEEncryption bool `yaml:"sse_encryption"` HTTPConfig HTTPConfig `yaml:"http_config"` SignatureVersion string `yaml:"signature_version"` + SSEConfig SSEConfig `yaml:"sse"` Inject InjectRequestMiddleware `yaml:"-"` } @@ -100,7 +101,11 @@ func (cfg *S3Config) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { f.StringVar(&cfg.AccessKeyID, prefix+"s3.access-key-id", "", "AWS Access Key ID") f.StringVar(&cfg.SecretAccessKey, prefix+"s3.secret-access-key", "", "AWS Secret Access Key") f.BoolVar(&cfg.Insecure, prefix+"s3.insecure", false, "Disable https on s3 connection.") - f.BoolVar(&cfg.SSEEncryption, prefix+"s3.sse-encryption", false, "Enable AES256 AWS Server Side Encryption") + + // TODO Remove in Cortex 1.9.0 + f.BoolVar(&cfg.SSEEncryption, prefix+"s3.sse-encryption", false, "Enable AWS Server Side Encryption [Deprecated: Use .sse instead. if s3.sse-encryption is enabled, it assumes .sse.type SSE-S3]") + + cfg.SSEConfig.RegisterFlagsWithPrefix(prefix+"s3.sse.", f) f.DurationVar(&cfg.HTTPConfig.IdleConnTimeout, prefix+"s3.http.idle-conn-timeout", 90*time.Second, "The maximum amount of time an idle connection will be held open.") f.DurationVar(&cfg.HTTPConfig.ResponseHeaderTimeout, prefix+"s3.http.response-header-timeout", 0, "If non-zero, specifies the amount of time to wait for a server's response headers after fully writing the request.") @@ -117,9 +122,9 @@ func (cfg *S3Config) Validate() error { } type S3ObjectClient struct { - bucketNames []string - S3 s3iface.S3API - sseEncryption *string + bucketNames []string + S3 s3iface.S3API + sseConfig *SSEParsedConfig } // NewS3ObjectClient makes a new S3-backed ObjectClient. @@ -140,19 +145,34 @@ func NewS3ObjectClient(cfg S3Config) (*S3ObjectClient, error) { s3Client.Handlers.Sign.Swap(v4.SignRequestHandler.Name, v2SignRequestHandler(cfg)) } - var sseEncryption *string - if cfg.SSEEncryption { - sseEncryption = aws.String("AES256") + sseCfg, err := buildSSEParsedConfig(cfg) + if err != nil { + return nil, errors.Wrap(err, "failed to build SSE config") } client := S3ObjectClient{ - S3: s3Client, - bucketNames: bucketNames, - sseEncryption: sseEncryption, + S3: s3Client, + bucketNames: bucketNames, + sseConfig: sseCfg, } return &client, nil } +func buildSSEParsedConfig(cfg S3Config) (*SSEParsedConfig, error) { + if cfg.SSEConfig.Type != "" { + return NewSSEParsedConfig(cfg.SSEConfig) + } + + // deprecated, but if used it assumes SSE-S3 type + if cfg.SSEEncryption { + return NewSSEParsedConfig(SSEConfig{ + Type: SSES3, + }) + } + + return nil, nil +} + func v2SignRequestHandler(cfg S3Config) request.NamedHandler { return request.NamedHandler{ Name: "v2.SignRequestHandler", @@ -324,15 +344,22 @@ func (a *S3ObjectClient) GetObject(ctx context.Context, objectKey string) (io.Re return resp.Body, nil } -// Put object into the store +// PutObject into the store func (a *S3ObjectClient) PutObject(ctx context.Context, objectKey string, object io.ReadSeeker) error { return instrument.CollectedRequest(ctx, "S3.PutObject", s3RequestDuration, instrument.ErrorCode, func(ctx context.Context) error { - _, err := a.S3.PutObjectWithContext(ctx, &s3.PutObjectInput{ - Body: object, - Bucket: aws.String(a.bucketFromKey(objectKey)), - Key: aws.String(objectKey), - ServerSideEncryption: a.sseEncryption, - }) + putObjectInput := &s3.PutObjectInput{ + Body: object, + Bucket: aws.String(a.bucketFromKey(objectKey)), + Key: aws.String(objectKey), + } + + if a.sseConfig != nil { + putObjectInput.ServerSideEncryption = aws.String(a.sseConfig.ServerSideEncryption) + putObjectInput.SSEKMSKeyId = a.sseConfig.KMSKeyID + putObjectInput.SSEKMSEncryptionContext = a.sseConfig.KMSEncryptionContext + } + + _, err := a.S3.PutObjectWithContext(ctx, putObjectInput) return err }) } diff --git a/aws/sse_config.go b/aws/sse_config.go new file mode 100644 index 000000000000..b62000fdc4d4 --- /dev/null +++ b/aws/sse_config.go @@ -0,0 +1,86 @@ +package aws + +import ( + "encoding/base64" + "encoding/json" + "flag" + + "github.com/pkg/errors" +) + +const ( + // SSEKMS config type constant to configure S3 server side encryption using KMS + // https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingKMSEncryption.html + SSEKMS = "SSE-KMS" + sseKMSType = "aws:kms" + // SSES3 config type constant to configure S3 server side encryption with AES-256 + // https://docs.aws.amazon.com/AmazonS3/latest/dev/UsingServerSideEncryption.html + SSES3 = "SSE-S3" + sseS3Type = "AES256" +) + +// SSEParsedConfig configures server side encryption (SSE) +// struct used internally to configure AWS S3 +type SSEParsedConfig struct { + ServerSideEncryption string + KMSKeyID *string + KMSEncryptionContext *string +} + +// SSEConfig configures S3 server side encryption +// struct that is going to receive user input (through config file or CLI) +type SSEConfig struct { + Type string `yaml:"type"` + KMSKeyID string `yaml:"kms_key_id"` + KMSEncryptionContext string `yaml:"kms_encryption_context"` +} + +// RegisterFlagsWithPrefix adds the flags required to config this to the given FlagSet +func (cfg *SSEConfig) RegisterFlagsWithPrefix(prefix string, f *flag.FlagSet) { + f.StringVar(&cfg.Type, prefix+"type", "", "Enable AWS Server Side Encryption. Only SSE-S3 and SSE-KMS are supported") + f.StringVar(&cfg.KMSKeyID, prefix+"kms-key-id", "", "KMS Key ID used to encrypt objects in S3") + f.StringVar(&cfg.KMSEncryptionContext, prefix+"kms-encryption-context", "", "KMS Encryption Context used for object encryption. It expects a JSON as a string.") +} + +// NewSSEParsedConfig creates a struct to configure server side encryption (SSE) +func NewSSEParsedConfig(cfg SSEConfig) (*SSEParsedConfig, error) { + switch cfg.Type { + case SSES3: + return &SSEParsedConfig{ + ServerSideEncryption: sseS3Type, + }, nil + case SSEKMS: + if cfg.KMSKeyID == "" { + return nil, errors.New("KMS key id must be passed when SSE-KMS encryption is selected") + } + + parsedKMSEncryptionContext, err := parseKMSEncryptionContext(cfg.KMSEncryptionContext) + if err != nil { + return nil, errors.Wrap(err, "failed to parse KMS encryption context") + } + + return &SSEParsedConfig{ + ServerSideEncryption: sseKMSType, + KMSKeyID: &cfg.KMSKeyID, + KMSEncryptionContext: parsedKMSEncryptionContext, + }, nil + default: + return nil, errors.New("SSE type is empty or invalid") + } +} + +func parseKMSEncryptionContext(kmsEncryptionContext string) (*string, error) { + if kmsEncryptionContext == "" { + return nil, nil + } + + // validates if kmsEncryptionContext is a valid JSON + jsonKMSEncryptionContext, err := json.Marshal(json.RawMessage(kmsEncryptionContext)) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal KMS encryption context") + } + + parsedKMSEncryptionContext := base64.StdEncoding.EncodeToString([]byte(jsonKMSEncryptionContext)) + + return &parsedKMSEncryptionContext, nil +} diff --git a/aws/sse_config_test.go b/aws/sse_config_test.go new file mode 100644 index 000000000000..7c6cfc4247f4 --- /dev/null +++ b/aws/sse_config_test.go @@ -0,0 +1,90 @@ +package aws + +import ( + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" +) + +func TestNewSSEParsedConfig(t *testing.T) { + kmsKeyID := "test" + kmsEncryptionContext := `{"a": "bc", "b": "cd"}` + // compact form of kmsEncryptionContext + parsedKMSEncryptionContext := "eyJhIjoiYmMiLCJiIjoiY2QifQ==" + + tests := []struct { + name string + params SSEConfig + expected *SSEParsedConfig + expectedErr error + }{ + { + name: "Test SSE encryption with SSES3 type", + params: SSEConfig{ + Type: SSES3, + }, + expected: &SSEParsedConfig{ + ServerSideEncryption: sseS3Type, + }, + }, + { + name: "Test SSE encryption with SSEKMS type without context", + params: SSEConfig{ + Type: SSEKMS, + KMSKeyID: kmsKeyID, + }, + expected: &SSEParsedConfig{ + ServerSideEncryption: sseKMSType, + KMSKeyID: &kmsKeyID, + }, + }, + { + name: "Test SSE encryption with SSEKMS type with context", + params: SSEConfig{ + Type: SSEKMS, + KMSKeyID: kmsKeyID, + KMSEncryptionContext: kmsEncryptionContext, + }, + expected: &SSEParsedConfig{ + ServerSideEncryption: sseKMSType, + KMSKeyID: &kmsKeyID, + KMSEncryptionContext: &parsedKMSEncryptionContext, + }, + }, + { + name: "Test invalid SSE type", + params: SSEConfig{ + Type: "invalid", + }, + expectedErr: errors.New("SSE type is empty or invalid"), + }, + { + name: "Test SSE encryption with SSEKMS type without KMS Key ID", + params: SSEConfig{ + Type: SSEKMS, + KMSKeyID: "", + }, + expectedErr: errors.New("KMS key id must be passed when SSE-KMS encryption is selected"), + }, + { + name: "Test SSE with invalid KMS encryption context JSON", + params: SSEConfig{ + Type: SSEKMS, + KMSKeyID: kmsKeyID, + KMSEncryptionContext: `INVALID_JSON`, + }, + expectedErr: errors.New("failed to parse KMS encryption context: failed to marshal KMS encryption context: json: error calling MarshalJSON for type json.RawMessage: invalid character 'I' looking for beginning of value"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := NewSSEParsedConfig(tt.params) + if tt.expectedErr != nil { + assert.Equal(t, tt.expectedErr.Error(), err.Error()) + } + assert.Equal(t, tt.expected, result) + }) + } +}