diff --git a/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProviderChain.h b/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProviderChain.h index ffcd57dde42..58bc8eb0977 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProviderChain.h +++ b/src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProviderChain.h @@ -46,8 +46,10 @@ namespace Aws void AddProvider(const std::shared_ptr& provider) { m_providerChain.push_back(provider); } - private: + private: Aws::Vector > m_providerChain; + std::shared_ptr m_cachedProvider; + mutable Aws::Utils::Threading::ReaderWriterLock m_cachedProviderLock; }; /** diff --git a/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp b/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp index 403bd380c46..81e30bacf30 100644 --- a/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp +++ b/src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp @@ -12,6 +12,7 @@ #include using namespace Aws::Auth; +using namespace Aws::Utils::Threading; static const char AWS_ECS_CONTAINER_CREDENTIALS_RELATIVE_URI[] = "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI"; static const char AWS_ECS_CONTAINER_CREDENTIALS_FULL_URI[] = "AWS_CONTAINER_CREDENTIALS_FULL_URI"; @@ -21,15 +22,24 @@ static const char DefaultCredentialsProviderChainTag[] = "DefaultAWSCredentialsP AWSCredentials AWSCredentialsProviderChain::GetAWSCredentials() { + ReaderLockGuard lock(m_cachedProviderLock); + if (m_cachedProvider) { + AWSCredentials credentials = m_cachedProvider->GetAWSCredentials(); + if (!credentials.GetAWSAccessKeyId().empty() && !credentials.GetAWSSecretKey().empty()) + { + return credentials; + } + } + lock.UpgradeToWriterLock(); for (auto&& credentialsProvider : m_providerChain) { AWSCredentials credentials = credentialsProvider->GetAWSCredentials(); if (!credentials.GetAWSAccessKeyId().empty() && !credentials.GetAWSSecretKey().empty()) { + m_cachedProvider = credentialsProvider; return credentials; } } - return AWSCredentials(); } diff --git a/tests/aws-cpp-sdk-core-tests/aws/auth/AWSCredentialsProviderTest.cpp b/tests/aws-cpp-sdk-core-tests/aws/auth/AWSCredentialsProviderTest.cpp index ffaf861291a..4c6c426082b 100644 --- a/tests/aws-cpp-sdk-core-tests/aws/auth/AWSCredentialsProviderTest.cpp +++ b/tests/aws-cpp-sdk-core-tests/aws/auth/AWSCredentialsProviderTest.cpp @@ -4,16 +4,13 @@ */ #include - #include #include #include #include #include -#include #include #include -#include #include #include #include @@ -21,7 +18,6 @@ #include #include #include -#include #include #include #include @@ -1224,3 +1220,129 @@ TEST_F(AWSCredentialsTest, TestExpiredState) ASSERT_TRUE(credentials.IsExpired()); ASSERT_TRUE(credentials.IsExpiredOrEmpty()); } + +class AWSCachedCredentialsTest : public Aws::Testing::AwsCppSdkGTestSuite +{ +public: + class MockCredentialsProvider : public AWSCredentialsProvider { + public: + AWSCredentials GetAWSCredentials() override { + if (!responseQueue.empty()) { + auto creds = responseQueue.front(); + responseQueue.pop(); + return creds; + } + return {}; + } + + void PushResponse(AWSCredentials &&creds) { + responseQueue.emplace(std::forward(creds)); + } + + private: + std::queue responseQueue; + }; + + class MockCredentialsProviderChain : public AWSCredentialsProviderChain { + public: + void AddMockProvider(std::shared_ptr provider) { + AddProvider(provider); + } + }; + + void SetUp() override { + cachedProviderChain = Aws::MakeShared(AllocationTag); + } + + std::shared_ptr cachedProviderChain; +}; + +TEST_F(AWSCachedCredentialsTest, ShouldSkipCredentialsChainForCachedValue) +{ + auto failFirstProvider = Aws::MakeShared(AllocationTag); + failFirstProvider->PushResponse({}); + failFirstProvider->PushResponse({"never", "see", "this"}); + + auto cachedProvider = Aws::MakeShared(AllocationTag); + cachedProvider->PushResponse({"sbiscigl", "was", "here"}); + cachedProvider->PushResponse({"sbiscigl", "was", "here"}); + cachedProvider->PushResponse({"sbiscigl", "was", "here"}); + cachedProvider->PushResponse({"sbiscigl", "was", "here"}); + + cachedProviderChain->AddMockProvider(failFirstProvider); + cachedProviderChain->AddMockProvider(cachedProvider); + + for (int i = 0; i < 4; ++i) { + auto creds = cachedProviderChain->GetAWSCredentials(); + ASSERT_EQ("sbiscigl", creds.GetAWSAccessKeyId()); + ASSERT_EQ("was", creds.GetAWSSecretKey()); + ASSERT_EQ("here", creds.GetSessionToken()); + } +} + +TEST_F(AWSCachedCredentialsTest, ShouldReplaceCachedWhenProviderFails) +{ + auto failFirstProvider = Aws::MakeShared(AllocationTag); + failFirstProvider->PushResponse({}); + failFirstProvider->PushResponse({"and", "no", "alarms"}); + failFirstProvider->PushResponse({"and", "no", "surprises"}); + + auto cachedFailingProvider = Aws::MakeShared(AllocationTag); + cachedFailingProvider->PushResponse({"sbiscigl", "was", "here"}); + cachedFailingProvider->PushResponse({}); + + cachedProviderChain->AddMockProvider(failFirstProvider); + cachedProviderChain->AddMockProvider(cachedFailingProvider); + + auto creds = cachedProviderChain->GetAWSCredentials(); + ASSERT_EQ("sbiscigl", creds.GetAWSAccessKeyId()); + ASSERT_EQ("was", creds.GetAWSSecretKey()); + ASSERT_EQ("here", creds.GetSessionToken()); + + creds = cachedProviderChain->GetAWSCredentials(); + ASSERT_EQ("and", creds.GetAWSAccessKeyId()); + ASSERT_EQ("no", creds.GetAWSSecretKey()); + ASSERT_EQ("alarms", creds.GetSessionToken()); + + creds = cachedProviderChain->GetAWSCredentials(); + ASSERT_EQ("and", creds.GetAWSAccessKeyId()); + ASSERT_EQ("no", creds.GetAWSSecretKey()); + ASSERT_EQ("surprises", creds.GetSessionToken()); +} + +TEST_F(AWSCachedCredentialsTest, ShouldCacheCredenitalAsync) +{ + auto cachedProvider = Aws::MakeShared(AllocationTag); + cachedProvider->PushResponse({"and", "no", "alarms"}); + cachedProvider->PushResponse({"and", "no", "surprises"}); + + auto fallback = Aws::MakeShared(AllocationTag); + fallback->PushResponse({"a", "quiet", "life"}); + + cachedProviderChain->AddMockProvider(cachedProvider); + cachedProviderChain->AddMockProvider(fallback); + + auto getCredentials = [](std::shared_ptr provider) -> AWSCredentials { + return provider->GetAWSCredentials(); + }; + + std::vector> futures; + futures.push_back(std::async(std::launch::async, getCredentials, cachedProviderChain)); + futures.push_back(std::async(std::launch::async, getCredentials, cachedProviderChain)); + + std::vector creds; + for (auto &future: futures) { + creds.push_back(future.get()); + } + + auto containCredentials = [](std::vector &found, + const AWSCredentials &credentials) -> bool { + return std::any_of(found.begin(), found.end(), [&credentials](const AWSCredentials& cred) -> bool { + return cred == credentials; + }); + }; + + ASSERT_TRUE(containCredentials(creds, {"and", "no", "alarms"})); + ASSERT_TRUE(containCredentials(creds, {"and", "no", "surprises"})); + ASSERT_FALSE(containCredentials(creds, {"a", "quiet", "life"})); +}