Skip to content

Commit

Permalink
check account id for eas bootstrap steps as well
Browse files Browse the repository at this point in the history
  • Loading branch information
GavinFrazar committed Sep 30, 2024
1 parent 76995c5 commit c54d2ab
Show file tree
Hide file tree
Showing 20 changed files with 64 additions and 57 deletions.
8 changes: 4 additions & 4 deletions lib/integrations/awsoidc/access_graph_aws_sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,13 @@ func (r *AccessGraphAWSIAMConfigureRequest) CheckAndSetDefaults() error {
// AccessGraphIAMConfigureClient describes the required methods to create the IAM Policies
// required for enrolling Access Graph AWS Sync into Teleport.
type AccessGraphIAMConfigureClient interface {
callerIdentityGetter
CallerIdentityGetter
// PutRolePolicy creates or replaces a Policy by its name in a IAM Role.
PutRolePolicy(ctx context.Context, params *iam.PutRolePolicyInput, optFns ...func(*iam.Options)) (*iam.PutRolePolicyOutput, error)
}

type defaultTAGIAMConfigureClient struct {
callerIdentityGetter
CallerIdentityGetter
*iam.Client
}

Expand All @@ -82,7 +82,7 @@ func NewAccessGraphIAMConfigureClient(ctx context.Context) (AccessGraphIAMConfig
}

return &defaultTAGIAMConfigureClient{
callerIdentityGetter: sts.NewFromConfig(cfg),
CallerIdentityGetter: sts.NewFromConfig(cfg),
Client: iam.NewFromConfig(cfg),
}, nil
}
Expand All @@ -96,7 +96,7 @@ func ConfigureAccessGraphSyncIAM(ctx context.Context, clt AccessGraphIAMConfigur
return trace.Wrap(err)
}

if err := checkAccountID(ctx, clt, req.AccountID); err != nil {
if err := CheckAccountID(ctx, clt, req.AccountID); err != nil {
return trace.Wrap(err)
}

Expand Down
4 changes: 2 additions & 2 deletions lib/integrations/awsoidc/access_graph_aws_sync_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func TestAccessGraphAWSIAMConfig(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
clt := mockAccessGraphAWSAMConfigClient{
callerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
CallerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
existingRoles: tt.mockExistingRoles,
}

Expand All @@ -140,7 +140,7 @@ func TestAccessGraphAWSIAMConfig(t *testing.T) {
}

type mockAccessGraphAWSAMConfigClient struct {
callerIdentityGetter
CallerIdentityGetter
existingRoles []string
}

Expand Down
8 changes: 4 additions & 4 deletions lib/integrations/awsoidc/aws_app_access_iam_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,14 +70,14 @@ func (r *AWSAppAccessConfigureRequest) CheckAndSetDefaults() error {

// AWSAppAccessConfigureClient describes the required methods to create the IAM Policies required for AWS App Access.
type AWSAppAccessConfigureClient interface {
callerIdentityGetter
CallerIdentityGetter
// PutRolePolicy creates or replaces a Policy by its name in a IAM Role.
PutRolePolicy(ctx context.Context, params *iam.PutRolePolicyInput, optFns ...func(*iam.Options)) (*iam.PutRolePolicyOutput, error)
}

type defaultAWSAppAccessConfigureClient struct {
*iam.Client
callerIdentityGetter
CallerIdentityGetter
}

// NewAWSAppAccessConfigureClient creates a new AWSAppAccessConfigureClient.
Expand All @@ -101,7 +101,7 @@ func NewAWSAppAccessConfigureClient(ctx context.Context) (AWSAppAccessConfigureC

return &defaultAWSAppAccessConfigureClient{
Client: iam.NewFromConfig(cfg),
callerIdentityGetter: sts.NewFromConfig(cfg),
CallerIdentityGetter: sts.NewFromConfig(cfg),
}, nil
}

Expand All @@ -116,7 +116,7 @@ func ConfigureAWSAppAccess(ctx context.Context, awsClient AWSAppAccessConfigureC
return trace.Wrap(err)
}

if err := checkAccountID(ctx, awsClient, req.AccountID); err != nil {
if err := CheckAccountID(ctx, awsClient, req.AccountID); err != nil {
return trace.Wrap(err)
}

Expand Down
4 changes: 2 additions & 2 deletions lib/integrations/awsoidc/aws_app_access_iam_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ func TestAWSAppAccessConfig(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
awsClient := &mockAWSAppAccessConfigClient{
callerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
CallerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
existingRoles: tt.mockExistingRoles,
}

Expand All @@ -140,7 +140,7 @@ func TestAWSAppAccessConfig(t *testing.T) {
}

type mockAWSAppAccessConfigClient struct {
callerIdentityGetter
CallerIdentityGetter
existingRoles []string
}

Expand Down
10 changes: 5 additions & 5 deletions lib/integrations/awsoidc/clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ func (j IdentityToken) GetIdentityToken() ([]byte, error) {
return []byte(j), nil
}

type callerIdentityGetter interface {
// CallerIdentityGetter is a subset of [sts.Client] that can be used to information about the caller identity.
type CallerIdentityGetter interface {
// GetCallerIdentity returns information about the caller identity.
GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error)
}

// checkAccountID is a helper func that check if the current caller account ID
// matches the expected account ID, in order to check that the command was run
// in the expected AWS account.
func checkAccountID(ctx context.Context, clt callerIdentityGetter, wantAccountID string) error {
// CheckAccountID is a helper func that check if the current caller account ID
// matches the expected account ID.
func CheckAccountID(ctx context.Context, clt CallerIdentityGetter, wantAccountID string) error {
if wantAccountID == "" {
return nil
}
Expand Down
6 changes: 3 additions & 3 deletions lib/integrations/awsoidc/deployservice.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,11 @@ type DeployServiceClient interface {
// Before deploying the service, it must ensure that the token exists and has the appropriate token rul.
TokenService

callerIdentityGetter
CallerIdentityGetter
}

type defaultDeployServiceClient struct {
callerIdentityGetter
CallerIdentityGetter
*ecs.Client
tokenServiceClient TokenService
}
Expand Down Expand Up @@ -355,7 +355,7 @@ func NewDeployServiceClient(ctx context.Context, clientReq *AWSClientRequest, to

return &defaultDeployServiceClient{
Client: ecsClient,
callerIdentityGetter: stsClient,
CallerIdentityGetter: stsClient,
tokenServiceClient: tokenServiceClient,
}, nil
}
Expand Down
8 changes: 4 additions & 4 deletions lib/integrations/awsoidc/deployservice_iam_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ func (r *DeployServiceIAMConfigureRequest) CheckAndSetDefaults() error {

// DeployServiceIAMConfigureClient describes the required methods to create the IAM Roles/Policies required for the DeployService action.
type DeployServiceIAMConfigureClient interface {
callerIdentityGetter
CallerIdentityGetter

// CreateRole creates a new IAM Role.
CreateRole(ctx context.Context, params *iam.CreateRoleInput, optFns ...func(*iam.Options)) (*iam.CreateRoleOutput, error)
Expand All @@ -124,7 +124,7 @@ type DeployServiceIAMConfigureClient interface {

type defaultDeployServiceIAMConfigureClient struct {
*iam.Client
callerIdentityGetter
CallerIdentityGetter
}

// NewDeployServiceIAMConfigureClient creates a new DeployServiceIAMConfigureClient.
Expand All @@ -140,7 +140,7 @@ func NewDeployServiceIAMConfigureClient(ctx context.Context, region string) (Dep

return &defaultDeployServiceIAMConfigureClient{
Client: iam.NewFromConfig(cfg),
callerIdentityGetter: sts.NewFromConfig(cfg),
CallerIdentityGetter: sts.NewFromConfig(cfg),
}, nil
}

Expand Down Expand Up @@ -169,7 +169,7 @@ func ConfigureDeployServiceIAM(ctx context.Context, clt DeployServiceIAMConfigur
return trace.Wrap(err)
}
req.AccountID = aws.ToString(callerIdentity.Account)
} else if err := checkAccountID(ctx, clt, req.AccountID); err != nil {
} else if err := CheckAccountID(ctx, clt, req.AccountID); err != nil {
return trace.Wrap(err)
}

Expand Down
4 changes: 2 additions & 2 deletions lib/integrations/awsoidc/deployservice_iam_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ func TestDeployServiceIAMConfig(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
clt := mockDeployServiceIAMConfigClient{
callerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
CallerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
existingRoles: tt.mockExistingRoles,
}

Expand All @@ -217,7 +217,7 @@ func TestDeployServiceIAMConfig(t *testing.T) {
}

type mockDeployServiceIAMConfigClient struct {
callerIdentityGetter
CallerIdentityGetter
existingRoles []string
}

Expand Down
8 changes: 4 additions & 4 deletions lib/integrations/awsoidc/ec2_ssm_iam_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (r *EC2SSMIAMConfigureRequest) CheckAndSetDefaults() error {

// EC2SSMConfigureClient describes the required methods to create the IAM Policies and SSM Document required for installing Teleport in EC2 instances.
type EC2SSMConfigureClient interface {
callerIdentityGetter
CallerIdentityGetter

// PutRolePolicy creates or replaces a Policy by its name in a IAM Role.
PutRolePolicy(ctx context.Context, params *iam.PutRolePolicyInput, optFns ...func(*iam.Options)) (*iam.PutRolePolicyOutput, error)
Expand All @@ -120,7 +120,7 @@ type EC2SSMConfigureClient interface {
type defaultEC2SSMConfigureClient struct {
*iam.Client
ssmClient *ssm.Client
callerIdentityGetter
CallerIdentityGetter
}

// CreateDocument creates a Amazon Web Services Systems Manager (SSM document).
Expand All @@ -142,7 +142,7 @@ func NewEC2SSMConfigureClient(ctx context.Context, region string) (EC2SSMConfigu
return &defaultEC2SSMConfigureClient{
Client: iam.NewFromConfig(cfg),
ssmClient: ssm.NewFromConfig(cfg),
callerIdentityGetter: sts.NewFromConfig(cfg),
CallerIdentityGetter: sts.NewFromConfig(cfg),
}, nil
}

Expand All @@ -167,7 +167,7 @@ func ConfigureEC2SSM(ctx context.Context, clt EC2SSMConfigureClient, req EC2SSMI
return trace.Wrap(err)
}

if err := checkAccountID(ctx, clt, req.AccountID); err != nil {
if err := CheckAccountID(ctx, clt, req.AccountID); err != nil {
return trace.Wrap(err)
}

Expand Down
4 changes: 2 additions & 2 deletions lib/integrations/awsoidc/ec2_ssm_iam_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ func TestEC2SSMIAMConfig(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
clt := mockEC2SSMIAMConfigClient{
callerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
CallerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
existingRoles: tt.mockExistingRoles,
}

Expand All @@ -228,7 +228,7 @@ func TestEC2SSMIAMConfig(t *testing.T) {
}

type mockEC2SSMIAMConfigClient struct {
callerIdentityGetter
CallerIdentityGetter
existingRoles []string
existingDocs map[string][]ssmtypes.Tag
}
Expand Down
8 changes: 4 additions & 4 deletions lib/integrations/awsoidc/eice_iam_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ func (r *EICEIAMConfigureRequest) CheckAndSetDefaults() error {

// EICEIAMConfigureClient describes the required methods to create the IAM Policies required for accessing EC2 instances usine EICE.
type EICEIAMConfigureClient interface {
callerIdentityGetter
CallerIdentityGetter
// PutRolePolicy creates or replaces a Policy by its name in a IAM Role.
PutRolePolicy(ctx context.Context, params *iam.PutRolePolicyInput, optFns ...func(*iam.Options)) (*iam.PutRolePolicyOutput, error)
}

type defaultEICEIAMConfigureClient struct {
callerIdentityGetter
CallerIdentityGetter
*iam.Client
}

Expand All @@ -93,7 +93,7 @@ func NewEICEIAMConfigureClient(ctx context.Context, region string) (EICEIAMConfi
}

return &defaultEICEIAMConfigureClient{
callerIdentityGetter: sts.NewFromConfig(cfg),
CallerIdentityGetter: sts.NewFromConfig(cfg),
Client: iam.NewFromConfig(cfg),
}, nil
}
Expand Down Expand Up @@ -130,7 +130,7 @@ func ConfigureEICEIAM(ctx context.Context, clt EICEIAMConfigureClient, req EICEI
return trace.Wrap(err)
}

if err := checkAccountID(ctx, clt, req.AccountID); err != nil {
if err := CheckAccountID(ctx, clt, req.AccountID); err != nil {
return trace.Wrap(err)
}

Expand Down
4 changes: 2 additions & 2 deletions lib/integrations/awsoidc/eice_iam_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func TestEICEIAMConfig(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
clt := mockEICEIAMConfigClient{
callerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
CallerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
existingRoles: tt.mockExistingRoles,
}

Expand All @@ -153,7 +153,7 @@ func TestEICEIAMConfig(t *testing.T) {
}

type mockEICEIAMConfigClient struct {
callerIdentityGetter
CallerIdentityGetter
existingRoles []string
}

Expand Down
8 changes: 4 additions & 4 deletions lib/integrations/awsoidc/eks_iam_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ func (r *EKSIAMConfigureRequest) CheckAndSetDefaults() error {

// EKSIAMConfigureClient describes the required methods to create the IAM Policies required for enrolling EKS clusters into Teleport.
type EKSIAMConfigureClient interface {
callerIdentityGetter
CallerIdentityGetter
// PutRolePolicy creates or replaces a Policy by its name in a IAM Role.
PutRolePolicy(ctx context.Context, params *iam.PutRolePolicyInput, optFns ...func(*iam.Options)) (*iam.PutRolePolicyOutput, error)
}

type defaultEKSEIAMConfigureClient struct {
callerIdentityGetter
CallerIdentityGetter
*iam.Client
}

Expand All @@ -94,7 +94,7 @@ func NewEKSIAMConfigureClient(ctx context.Context, region string) (EKSIAMConfigu

return &defaultEKSEIAMConfigureClient{
Client: iam.NewFromConfig(cfg),
callerIdentityGetter: sts.NewFromConfig(cfg),
CallerIdentityGetter: sts.NewFromConfig(cfg),
}, nil
}

Expand All @@ -118,7 +118,7 @@ func ConfigureEKSIAM(ctx context.Context, clt EKSIAMConfigureClient, req EKSIAMC
return trace.Wrap(err)
}

if err := checkAccountID(ctx, clt, req.AccountID); err != nil {
if err := CheckAccountID(ctx, clt, req.AccountID); err != nil {
return trace.Wrap(err)
}

Expand Down
4 changes: 2 additions & 2 deletions lib/integrations/awsoidc/eks_iam_config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func TestEKSAMConfig(t *testing.T) {
} {
t.Run(tt.name, func(t *testing.T) {
clt := mockEKSIAMConfigClient{
callerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
CallerIdentityGetter: mockSTSClient{accountID: tt.mockAccountID},
existingRoles: tt.mockExistingRoles,
}

Expand All @@ -151,7 +151,7 @@ func TestEKSAMConfig(t *testing.T) {
}

type mockEKSIAMConfigClient struct {
callerIdentityGetter
CallerIdentityGetter
existingRoles []string
}

Expand Down
2 changes: 0 additions & 2 deletions lib/integrations/awsoidc/externalauditstorage_iam_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,6 @@ func ConfigureExternalAuditStorage(
return trace.Wrap(err, "attempting to find caller's AWS account ID: call to sts:GetCallerIdentity failed")
}
policyCfg.Account = aws.ToString(stsResp.Account)
} else if err := checkAccountID(ctx, clt, policyCfg.Account); err != nil {
return trace.Wrap(err)
}

policyDoc, err := awslib.PolicyDocumentForExternalAuditStorage(policyCfg)
Expand Down
6 changes: 3 additions & 3 deletions lib/integrations/awsoidc/idp_iam_config.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (r *IdPIAMConfigureRequest) CheckAndSetDefaults() error {
// IdPIAMConfigureClient describes the required methods to create the AWS OIDC IdP and a Role that trusts that identity provider.
// There is no guarantee that the client is thread safe.
type IdPIAMConfigureClient interface {
callerIdentityGetter
CallerIdentityGetter

// CreateOpenIDConnectProvider creates an IAM OIDC IdP.
CreateOpenIDConnectProvider(ctx context.Context, params *iam.CreateOpenIDConnectProviderInput, optFns ...func(*iam.Options)) (*iam.CreateOpenIDConnectProviderOutput, error)
Expand All @@ -131,7 +131,7 @@ type defaultIdPIAMConfigureClient struct {

*iam.Client
awsConfig aws.Config
callerIdentityGetter
CallerIdentityGetter
}

// NewIdPIAMConfigureClient creates a new IdPIAMConfigureClient.
Expand All @@ -155,7 +155,7 @@ func NewIdPIAMConfigureClient(ctx context.Context) (IdPIAMConfigureClient, error
httpClient: httpClient,
awsConfig: cfg,
Client: iam.NewFromConfig(cfg),
callerIdentityGetter: sts.NewFromConfig(cfg),
CallerIdentityGetter: sts.NewFromConfig(cfg),
}, nil
}

Expand Down
Loading

0 comments on commit c54d2ab

Please sign in to comment.