Skip to content

Commit

Permalink
cache credential chain results
Browse files Browse the repository at this point in the history
  • Loading branch information
sbiscigl committed Sep 26, 2023
1 parent 16f2ed6 commit e69f664
Show file tree
Hide file tree
Showing 3 changed files with 140 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@ namespace Aws
void AddProvider(const std::shared_ptr<AWSCredentialsProvider>& provider) { m_providerChain.push_back(provider); }


private:
private:
Aws::Vector<std::shared_ptr<AWSCredentialsProvider> > m_providerChain;
std::shared_ptr<AWSCredentialsProvider> m_cachedProvider;
mutable Aws::Utils::Threading::ReaderWriterLock m_cachedProviderLock;
};

/**
Expand Down
12 changes: 11 additions & 1 deletion src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <aws/core/utils/logging/LogMacros.h>

using namespace Aws::Auth;
using namespace Aws::Utils::Threading;

static const char AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI[] = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI";
static const char AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI[] = "AWS_CONTAINER_CREDENTIALS_FULL_URI";
Expand All @@ -21,15 +22,24 @@ static const char DefaultCredentialsProviderChainTag[] = "DefaultAWSCredentialsP

AWSCredentials AWSCredentialsProviderChain::GetAWSCredentials()
{
ReaderLockGuard lock(m_cachedProviderLock);
if (m_cachedProvider) {
AWSCredentials credentials = m_cachedProvider->GetAWSCredentials();
if (!credentials.GetAWSAccessKeyId().empty() && !credentials.GetAWSSecretKey().empty())
{
return credentials;
}
}
lock.UpgradeToWriterLock();
for (auto&& credentialsProvider : m_providerChain)
{
AWSCredentials credentials = credentialsProvider->GetAWSCredentials();
if (!credentials.GetAWSAccessKeyId().empty() && !credentials.GetAWSSecretKey().empty())
{
m_cachedProvider = credentialsProvider;
return credentials;
}
}

return AWSCredentials();
}

Expand Down
130 changes: 126 additions & 4 deletions tests/aws-cpp-sdk-core-tests/aws/auth/AWSCredentialsProviderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,20 @@
*/

#include <aws/testing/AwsCppSdkGTestSuite.h>

#include <aws/testing/mocks/aws/auth/MockAWSHttpResourceClient.h>
#include <aws/testing/platform/PlatformTesting.h>
#include <aws/core/auth/AWSCredentialsProvider.h>
#include <aws/core/platform/Environment.h>
#include <aws/core/platform/FileSystem.h>
#include <aws/core/utils/UnreferencedParam.h>
#include <aws/core/utils/memory/stl/AWSStreamFwd.h>
#include <aws/core/utils/memory/stl/AWSStringStream.h>
#include <aws/core/utils/FileSystemUtils.h>
#include <aws/core/config/AWSProfileConfigLoader.h>
#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/core/client/AWSError.h>
#include <aws/testing/mocks/http/MockHttpClient.h>
#include <aws/core/http/standard/StandardHttpResponse.h>
#include <aws/core/auth/STSCredentialsProvider.h>
#include <aws/core/client/SpecifiedRetryableErrorsRetryStrategy.h>
#include <stdlib.h>
#include <thread>
#include <fstream>
#include <aws/core/utils/logging/LogMacros.h>
Expand Down Expand Up @@ -1224,3 +1220,129 @@ TEST_F(AWSCredentialsTest, TestExpiredState)
ASSERT_TRUE(credentials.IsExpired());
ASSERT_TRUE(credentials.IsExpiredOrEmpty());
}

class AWSCachedCredentialsTest : public Aws::Testing::AwsCppSdkGTestSuite
{
public:
class MockCredentialsProvider : public AWSCredentialsProvider {
public:
AWSCredentials GetAWSCredentials() override {
if (!responseQueue.empty()) {
auto creds = responseQueue.front();
responseQueue.pop();
return creds;
}
return {};
}

void PushResponse(AWSCredentials &&creds) {
responseQueue.emplace(std::forward<AWSCredentials>(creds));
}

private:
std::queue<AWSCredentials> responseQueue;
};

class MockCredentialsProviderChain : public AWSCredentialsProviderChain {
public:
void AddMockProvider(std::shared_ptr<MockCredentialsProvider> provider) {
AddProvider(provider);
}
};

void SetUp() override {
cachedProviderChain = Aws::MakeShared<MockCredentialsProviderChain>(AllocationTag);
}

std::shared_ptr<MockCredentialsProviderChain> cachedProviderChain;
};

TEST_F(AWSCachedCredentialsTest, ShouldSkipCredentialsChainForCachedValue)
{
auto failFirstProvider = Aws::MakeShared<MockCredentialsProvider>(AllocationTag);
failFirstProvider->PushResponse({});
failFirstProvider->PushResponse({"never", "see", "this"});

auto cachedProvider = Aws::MakeShared<MockCredentialsProvider>(AllocationTag);
cachedProvider->PushResponse({"sbiscigl", "was", "here"});
cachedProvider->PushResponse({"sbiscigl", "was", "here"});
cachedProvider->PushResponse({"sbiscigl", "was", "here"});
cachedProvider->PushResponse({"sbiscigl", "was", "here"});

cachedProviderChain->AddMockProvider(failFirstProvider);
cachedProviderChain->AddMockProvider(cachedProvider);

for (int i = 0; i < 4; ++i) {
auto creds = cachedProviderChain->GetAWSCredentials();
ASSERT_EQ("sbiscigl", creds.GetAWSAccessKeyId());
ASSERT_EQ("was", creds.GetAWSSecretKey());
ASSERT_EQ("here", creds.GetSessionToken());
}
}

TEST_F(AWSCachedCredentialsTest, ShouldReplaceCachedWhenProviderFails)
{
auto failFirstProvider = Aws::MakeShared<MockCredentialsProvider>(AllocationTag);
failFirstProvider->PushResponse({});
failFirstProvider->PushResponse({"and", "no", "alarms"});
failFirstProvider->PushResponse({"and", "no", "surprises"});

auto cachedFailingProvider = Aws::MakeShared<MockCredentialsProvider>(AllocationTag);
cachedFailingProvider->PushResponse({"sbiscigl", "was", "here"});
cachedFailingProvider->PushResponse({});

cachedProviderChain->AddMockProvider(failFirstProvider);
cachedProviderChain->AddMockProvider(cachedFailingProvider);

auto creds = cachedProviderChain->GetAWSCredentials();
ASSERT_EQ("sbiscigl", creds.GetAWSAccessKeyId());
ASSERT_EQ("was", creds.GetAWSSecretKey());
ASSERT_EQ("here", creds.GetSessionToken());

creds = cachedProviderChain->GetAWSCredentials();
ASSERT_EQ("and", creds.GetAWSAccessKeyId());
ASSERT_EQ("no", creds.GetAWSSecretKey());
ASSERT_EQ("alarms", creds.GetSessionToken());

creds = cachedProviderChain->GetAWSCredentials();
ASSERT_EQ("and", creds.GetAWSAccessKeyId());
ASSERT_EQ("no", creds.GetAWSSecretKey());
ASSERT_EQ("surprises", creds.GetSessionToken());
}

TEST_F(AWSCachedCredentialsTest, ShouldCacheCredenitalAsync)
{
auto cachedProvider = Aws::MakeShared<MockCredentialsProvider>(AllocationTag);
cachedProvider->PushResponse({"and", "no", "alarms"});
cachedProvider->PushResponse({"and", "no", "surprises"});

auto fallback = Aws::MakeShared<MockCredentialsProvider>(AllocationTag);
fallback->PushResponse({"a", "quiet", "life"});

cachedProviderChain->AddMockProvider(cachedProvider);
cachedProviderChain->AddMockProvider(fallback);

auto getCredentials = [](std::shared_ptr<MockCredentialsProviderChain> provider) -> AWSCredentials {
return provider->GetAWSCredentials();
};

std::vector<std::future<AWSCredentials>> futures;
futures.push_back(std::async(std::launch::async, getCredentials, cachedProviderChain));
futures.push_back(std::async(std::launch::async, getCredentials, cachedProviderChain));

std::vector<AWSCredentials> creds;
for (auto &future: futures) {
creds.push_back(future.get());
}

auto containCredentials = [](std::vector<AWSCredentials> &found,
const AWSCredentials &credentials) -> bool {
return std::any_of(found.begin(), found.end(), [&credentials](const AWSCredentials& cred) -> bool {
return cred == credentials;
});
};

ASSERT_TRUE(containCredentials(creds, {"and", "no", "alarms"}));
ASSERT_TRUE(containCredentials(creds, {"and", "no", "surprises"}));
ASSERT_FALSE(containCredentials(creds, {"a", "quiet", "life"}));
}

0 comments on commit e69f664

Please sign in to comment.