diff --git a/cmd/manager/main.go b/cmd/manager/main.go index 61fce57c6be..71be9d3336e 100644 --- a/cmd/manager/main.go +++ b/cmd/manager/main.go @@ -333,9 +333,8 @@ func doRun(_ *cobra.Command, _ []string) error { toWatch = append(toWatch, filepath.Join(caDir, certificates.KeyFileName), filepath.Join(caDir, certificates.CertFileName), - // TODO support ca.crt and ca.key - // filepath.Join(caDir, certificates.CAKeyFileName), - // filepath.Join(caDir, certificates.CAFileName), + filepath.Join(caDir, certificates.CAKeyFileName), + filepath.Join(caDir, certificates.CAFileName), ) } diff --git a/pkg/controller/common/certificates/ca_reconcile.go b/pkg/controller/common/certificates/ca_reconcile.go index 91e18724398..f7120ad624e 100644 --- a/pkg/controller/common/certificates/ca_reconcile.go +++ b/pkg/controller/common/certificates/ca_reconcile.go @@ -11,6 +11,7 @@ import ( "crypto/x509" "crypto/x509/pkix" "fmt" + "github.com/elastic/cloud-on-k8s/pkg/utils/fs" "io/ioutil" "path/filepath" "time" @@ -234,8 +235,37 @@ func internalSecretForCA( }, nil } +func detectCAFileNames(path string) (string, string, error) { + files := map[string]bool{ + CertFileName: false, + KeyFileName: false, + CAFileName: false, + CAKeyFileName: false, + } + for f := range files { + exists, err := fs.FileExists(filepath.Join(path, f)) + if err != nil { + return "", "", err + } + files[f] = exists + } + switch { + case (files[CertFileName] || files[KeyFileName]) && files[CAKeyFileName]: + return "", "", fmt.Errorf("both tls.* and ca.* files exist, configuration error") + case files[CAFileName] && files[CAKeyFileName]: + return filepath.Join(path, CAFileName), filepath.Join(CAKeyFileName), nil + case files[CertFileName] && files[KeyFileName]: + return filepath.Join(path, CertFileName), filepath.Join(path, KeyFileName), nil + } + return "", "", fmt.Errorf("no CA certificate files found: %+v", files) +} + func BuildCAFromFile(path string) (*CA, error) { - certFile := filepath.Join(path, CertFileName) + certFile, privateKeyFile, err := detectCAFileNames(path) + if err != nil { + return nil, err + } + bytes, err := ioutil.ReadFile(certFile) if err != nil { return nil, err @@ -254,7 +284,6 @@ func BuildCAFromFile(path string) (*CA, error) { } cert := certs[0] - privateKeyFile := filepath.Join(path, KeyFileName) privateKeyBytes, err := ioutil.ReadFile(privateKeyFile) if err != nil { return nil, err diff --git a/pkg/controller/common/certificates/ca_reconcile_integration_test.go b/pkg/controller/common/certificates/ca_reconcile_integration_test.go index a39be727453..00cfb106974 100644 --- a/pkg/controller/common/certificates/ca_reconcile_integration_test.go +++ b/pkg/controller/common/certificates/ca_reconcile_integration_test.go @@ -7,6 +7,7 @@ package certificates import ( + "fmt" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "io/ioutil" @@ -103,3 +104,68 @@ func TestBuildCAFromFile(t *testing.T) { }) } } + +func Test_detectCAFileNames(t *testing.T) { + tests := []struct { + name string + files []string + wantCert string + wantKey string + wantErr bool + }{ + { + name: "happy path ca*", + files: []string{"ca.crt", "ca.key"}, + wantCert: "ca.crt", + wantKey: "ca.key", + wantErr: false, + }, + { + name: "happy path tls*", + files: []string{"tls.crt", "tls.key"}, + wantCert: "tls.crt", + wantKey: "tls.key", + wantErr: false, + }, + { + name: "tls.* with ca.crt OK", + files: []string{"tls.crt", "tls.key", "ca.crt"}, + wantCert: "tls.crt", + wantKey: "tls.key", + wantErr: false, + }, + { + name: "mixed tls.* and ca.* NOK", + files: []string{"tls.crt", "tls.key", "ca.crt", "ca.key"}, + wantCert: "", + wantKey: "", + wantErr: true, + }, + { + name: "no valid combination of files", + files: nil, + wantCert: "", + wantKey: "", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + dir, err := ioutil.TempDir("", "detect_ca_file_names") + require.NoError(t, err) + defer os.RemoveAll(dir) + for _, f := range tt.files { + require.NoError(t, ioutil.WriteFile(filepath.Join(dir, f), []byte("contents"), 0644)) + } + + cert, key, err := detectCAFileNames(dir) + if tt.wantErr != (err != nil) { + t.Errorf(fmt.Sprintf("want err %v got %v,files: %v ", tt.wantErr, err, tt.files)) + } + if err == nil { + assert.Equalf(t, tt.wantCert, filepath.Base(cert), "detectCAFileNames(), files: %v", tt.files) + assert.Equalf(t, tt.wantKey, filepath.Base(key), "detectCAFileNames(), files: %v", tt.files) + } + }) + } +} diff --git a/pkg/utils/fs/utils.go b/pkg/utils/fs/utils.go new file mode 100644 index 00000000000..3e1cf5f6611 --- /dev/null +++ b/pkg/utils/fs/utils.go @@ -0,0 +1,17 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License 2.0; +// you may not use this file except in compliance with the Elastic License 2.0. + +package fs + +import "os" + +func FileExists(file string) (bool, error) { + _, err := os.Stat(file) + if err != nil && os.IsNotExist(err) { + return false, nil + } else if err != nil { + return false, err + } + return true, nil +}