diff --git a/cmd/credential_process.go b/cmd/credential_process.go index 09ae52b..63bbdb6 100644 --- a/cmd/credential_process.go +++ b/cmd/credential_process.go @@ -89,7 +89,7 @@ func generateCredentialProcessConfig(destination string) error { if destination == "" { return fmt.Errorf("no destination provided") } - client, err := creds.GetClient(region) + client, err := creds.GetClient() if err != nil { return err } diff --git a/cmd/helpers.go b/cmd/helpers.go index 579d2d7..d8ef464 100644 --- a/cmd/helpers.go +++ b/cmd/helpers.go @@ -197,7 +197,7 @@ func preInteractiveCheck(region string, client *creds.Client) (*creds.Client, er // If a client was not provided, create one using the provided region if client == nil { var err error - client, err = creds.GetClient(region) + client, err = creds.GetClient() if err != nil { return nil, err } diff --git a/cmd/list.go b/cmd/list.go index dd49ce5..17dc219 100644 --- a/cmd/list.go +++ b/cmd/list.go @@ -47,7 +47,7 @@ var listCmd = &cobra.Command{ } func roleList() (string, error) { - client, err := creds.GetClient(region) + client, err := creds.GetClient() if err != nil { return "", err } diff --git a/cmd/open.go b/cmd/open.go index f37a06c..ee19089 100644 --- a/cmd/open.go +++ b/cmd/open.go @@ -53,7 +53,7 @@ func runOpen(cmd *cobra.Command, args []string) error { return errors.New("Resource type sns and sqs require region in the arn") } var resourceURL string - client, err := creds.GetClient(region) + client, err := creds.GetClient() if err != nil { logging.LogError(err, "Error getting client") return err diff --git a/pkg/creds/consoleme.go b/pkg/creds/consoleme.go index 9354574..d5f7ee4 100644 --- a/pkg/creds/consoleme.go +++ b/pkg/creds/consoleme.go @@ -30,13 +30,15 @@ import ( "strings" "time" + "github.com/netflix/weep/pkg/httpAuth" + "github.com/netflix/weep/pkg/httpAuth/custom" + "github.com/netflix/weep/pkg/util" "github.com/netflix/weep/pkg/aws" "github.com/netflix/weep/pkg/config" werrors "github.com/netflix/weep/pkg/errors" "github.com/netflix/weep/pkg/httpAuth/challenge" - "github.com/netflix/weep/pkg/httpAuth/mtls" "github.com/netflix/weep/pkg/logging" "github.com/netflix/weep/pkg/metadata" @@ -48,8 +50,6 @@ import ( var clientVersion = fmt.Sprintf("%s", metadata.Version) var userAgent = "weep/" + clientVersion + " Go-http-client/1.1" -var clientFactoryOverride ClientFactory -var preflightFunctions = make([]RequestPreflight, 0) // HTTPClient is the interface we expect HTTP clients to implement. type HTTPClient interface { @@ -66,65 +66,15 @@ type Client struct { Region string } -type ClientFactory func() (*http.Client, error) - -// RegisterClientFactory overrides Weep's standard config-based ConsoleMe client -// creation with a ClientFactory. This function will be called during the creation -// of all ConsoleMe clients. -func RegisterClientFactory(factory ClientFactory) { - clientFactoryOverride = factory -} - -type RequestPreflight func(req *http.Request) error - -// RegisterRequestPreflight adds a RequestPreflight function which will be called in the -// order of registration during the creation of a ConsoleMe request. -func RegisterRequestPreflight(preflight RequestPreflight) { - preflightFunctions = append(preflightFunctions, preflight) -} - // GetClient creates an authenticated ConsoleMe client -func GetClient(region string) (*Client, error) { +func GetClient() (*Client, error) { var client *Client consoleMeUrl := viper.GetString("consoleme_url") - authenticationMethod := viper.GetString("authentication_method") - - if clientFactoryOverride != nil { - customClient, err := clientFactoryOverride() - if err != nil { - return client, err - } - client, err = NewClient(consoleMeUrl, "", customClient) - if err != nil { - return client, err - } - } else if authenticationMethod == "mtls" { - mtlsClient, err := mtls.NewHTTPClient() - if err != nil { - return client, err - } - client, err = NewClient(consoleMeUrl, "", mtlsClient) - if err != nil { - return client, err - } - } else if authenticationMethod == "challenge" { - err := challenge.RefreshChallenge() - if err != nil { - return client, err - } - httpClient, err := challenge.NewHTTPClient(consoleMeUrl) - if err != nil { - return client, err - } - client, err = NewClient(consoleMeUrl, "", httpClient) - if err != nil { - return client, err - } - } else { - return nil, fmt.Errorf("Authentication method unsupported or not provided.") + httpClient, err := httpAuth.GetAuthenticatedClient() + if err != nil { + return client, err } - - return client, nil + return NewClient(consoleMeUrl, "", httpClient) } // NewClient takes a ConsoleMe hostname and *http.Client, and returns a @@ -147,18 +97,6 @@ func NewClient(hostname string, region string, httpc *http.Client) (*Client, err return c, nil } -func runPreflightFunctions(req *http.Request) error { - var err error - if preflightFunctions != nil { - for _, preflight := range preflightFunctions { - if err = preflight(req); err != nil { - return err - } - } - } - return nil -} - func (c *Client) buildRequest(method string, resource string, body io.Reader, apiPrefix string) (*http.Request, error) { urlStr := c.Host + apiPrefix + resource req, err := http.NewRequest(method, urlStr, body) @@ -167,7 +105,7 @@ func (c *Client) buildRequest(method string, resource string, body io.Reader, ap } req.Header.Set("User-Agent", userAgent) req.Header.Add("Content-Type", "application/json") - err = runPreflightFunctions(req) + err = custom.RunPreflightFunctions(req) if err != nil { return nil, err } @@ -579,7 +517,7 @@ func GetCredentialsC(client HTTPClient, role string, ipRestrict bool, assumeRole // GetCredentials requests credentials from ConsoleMe then follows the provided chain of roles to // assume. Roles are assumed in the order in which they appear in the assumeRole slice. func GetCredentials(role string, ipRestrict bool, assumeRole []string, region string) (*aws.Credentials, error) { - client, err := GetClient(region) + client, err := GetClient() if err != nil { return nil, err } diff --git a/pkg/httpAuth/custom/custom.go b/pkg/httpAuth/custom/custom.go new file mode 100644 index 0000000..8751fec --- /dev/null +++ b/pkg/httpAuth/custom/custom.go @@ -0,0 +1,45 @@ +package custom + +import "net/http" + +var isOverridden bool +var clientFactoryOverride ClientFactory +var preflightFunctions = make([]RequestPreflight, 0) + +type ClientFactory func() (*http.Client, error) + +func UseCustom() bool { + return isOverridden +} + +func NewHTTPClient() (*http.Client, error) { + return clientFactoryOverride() +} + +// RegisterClientFactory overrides Weep's standard config-based ConsoleMe client +// creation with a ClientFactory. This function will be called during the creation +// of all ConsoleMe clients. +func RegisterClientFactory(factory ClientFactory) { + clientFactoryOverride = factory + isOverridden = true +} + +type RequestPreflight func(req *http.Request) error + +// RegisterRequestPreflight adds a RequestPreflight function which will be called in the +// order of registration during the creation of a ConsoleMe request. +func RegisterRequestPreflight(preflight RequestPreflight) { + preflightFunctions = append(preflightFunctions, preflight) +} + +func RunPreflightFunctions(req *http.Request) error { + var err error + if preflightFunctions != nil { + for _, preflight := range preflightFunctions { + if err = preflight(req); err != nil { + return err + } + } + } + return nil +} diff --git a/pkg/httpAuth/httpAuth.go b/pkg/httpAuth/httpAuth.go new file mode 100644 index 0000000..2ad46ae --- /dev/null +++ b/pkg/httpAuth/httpAuth.go @@ -0,0 +1,28 @@ +package httpAuth + +import ( + "fmt" + "net/http" + + "github.com/netflix/weep/pkg/httpAuth/challenge" + "github.com/netflix/weep/pkg/httpAuth/custom" + "github.com/netflix/weep/pkg/httpAuth/mtls" + "github.com/spf13/viper" +) + +func GetAuthenticatedClient() (*http.Client, error) { + authenticationMethod := viper.GetString("authentication_method") + consoleMeUrl := viper.GetString("consoleme_url") + if custom.UseCustom() { + return custom.NewHTTPClient() + } else if authenticationMethod == "mtls" { + return mtls.NewHTTPClient() + } else if authenticationMethod == "challenge" { + err := challenge.RefreshChallenge() + if err != nil { + return nil, err + } + return challenge.NewHTTPClient(consoleMeUrl) + } + return nil, fmt.Errorf("Authentication method unsupported or not provided.") +} diff --git a/pkg/server/ecsCredentialsHandler.go b/pkg/server/ecsCredentialsHandler.go index 4314144..d01ebea 100644 --- a/pkg/server/ecsCredentialsHandler.go +++ b/pkg/server/ecsCredentialsHandler.go @@ -56,7 +56,7 @@ func parseAssumeRoleQuery(r *http.Request) ([]string, error) { func getCredentialHandler(region string) func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { - var client, err = creds.GetClient(region) + var client, err = creds.GetClient() if err != nil { logging.Log.Error(err) util.WriteError(w, err.Error(), http.StatusBadRequest) diff --git a/pkg/server/server.go b/pkg/server/server.go index b31c352..515949f 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -31,7 +31,7 @@ func Run(host string, port int, role, region string, shutdown chan os.Signal) er if isServingIMDS { logging.Log.Infof("Configuring weep IMDS service for role %s", role) - client, err := creds.GetClient(region) + client, err := creds.GetClient() if err != nil { return err } diff --git a/pkg/swag/swag.go b/pkg/swag/swag.go index 3280942..cfb5e16 100644 --- a/pkg/swag/swag.go +++ b/pkg/swag/swag.go @@ -5,7 +5,8 @@ import ( "fmt" "net/http" - "github.com/netflix/weep/pkg/httpAuth/mtls" + "github.com/netflix/weep/pkg/creds" + "github.com/spf13/viper" ) @@ -15,7 +16,11 @@ type SwagResponse struct { func getClient() (*http.Client, error) { if viper.GetBool("swag.use_mtls") { - return mtls.NewHTTPClient() + consoleMeClient, err := creds.GetClient() + if err != nil { + return nil, err + } + return &consoleMeClient.Client, nil } return http.DefaultClient, nil }