diff --git a/src/aws-cpp-sdk-core/CMakeLists.txt b/src/aws-cpp-sdk-core/CMakeLists.txt index ec10c63ead3..924e4f2d0a9 100644 --- a/src/aws-cpp-sdk-core/CMakeLists.txt +++ b/src/aws-cpp-sdk-core/CMakeLists.txt @@ -91,6 +91,7 @@ file(GLOB SMITHY_IDENTITY_RESOLVER_IMPL_HEADERS "include/smithy/identity/resolve file(GLOB SMITHY_IDENTITY_SIGNER_HEADERS "include/smithy/identity/signer/*.h") file(GLOB SMITHY_IDENTITY_SIGNER_BUILTIN_HEADERS "include/smithy/identity/signer/built-in/*.h") file(GLOB SMITHY_INTERCEPTOR_HEADERS "include/smithy/interceptor/*.h") +file(GLOB SMITHY_INTERCEPTOR_IMPL_HEADERS "include/smithy/interceptor/impl/*.h") file(GLOB AWS_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/source/*.cpp") file(GLOB AWS_TINYXML2_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/source/external/tinyxml2/*.cpp") @@ -363,6 +364,7 @@ file(GLOB AWS_NATIVE_SDK_COMMON_HEADERS ${SMITHY_IDENTITY_RESOLVER_BUILTIN_HEADERS} ${OPTEL_HEADERS} ${SMITHY_INTERCEPTOR_HEADERS} + ${SMITHY_INTERCEPTOR_IMPL_HEADERS} ) # misc platform-specific, not related to features (encryption/http clients) @@ -495,6 +497,7 @@ if(MSVC) source_group("Header Files\\smithy\\identity\\signer" FILES ${SMITHY_IDENTITY_SIGNER_HEADERS}) source_group("Header Files\\smithy\\identity\\signer\\built-in" FILES ${SMITHY_IDENTITY_SIGNER_BUILTIN_HEADERS}) source_group("Header Files\\smithy\\interceptor" FILES ${SMITHY_INTERCEPTOR_HEADERS}) + source_group("Header Files\\smithy\\interceptor" FILES ${SMITHY_INTERCEPTOR_IMPL_HEADERS}) # http client conditional headers if(ENABLE_CURL_CLIENT) @@ -774,6 +777,7 @@ install (FILES ${SMITHY_IDENTITY_RESOLVER_IMPL_HEADERS} DESTINATION ${INCLUDE_DI install (FILES ${SMITHY_IDENTITY_SIGNER_HEADERS} DESTINATION ${INCLUDE_DIRECTORY}/smithy/identity/signer) install (FILES ${SMITHY_IDENTITY_SIGNER_BUILTIN_HEADERS} DESTINATION ${INCLUDE_DIRECTORY}/smithy/identity/signer/built-in) install (FILES ${SMITHY_INTERCEPTOR_HEADERS} DESTINATION ${INCLUDE_DIRECTORY}/smithy/interceptor) +install (FILES ${SMITHY_INTERCEPTOR_IMPL_HEADERS} DESTINATION ${INCLUDE_DIRECTORY}/smithy/interceptor) # android logcat headers if(PLATFORM_ANDROID) diff --git a/src/aws-cpp-sdk-core/include/aws/core/client/AWSClient.h b/src/aws-cpp-sdk-core/include/aws/core/client/AWSClient.h index 19bf979d7d5..749ac31c8ae 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/client/AWSClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/client/AWSClient.h @@ -16,6 +16,7 @@ #include #include #include +#include #include #include @@ -338,11 +339,10 @@ namespace Aws */ bool AdjustClockSkew(HttpResponseOutcome& outcome, const char* signerName) const; void AddHeadersToRequest(const std::shared_ptr& httpRequest, const Http::HeaderValueCollection& headerValues) const; - void AddChecksumToRequest(const std::shared_ptr& HttpRequest, const Aws::AmazonWebServiceRequest& request) const; void AddContentBodyToRequest(const std::shared_ptr& httpRequest, const std::shared_ptr& body, bool needsContentMd5 = false, bool isChunked = false) const; void AddCommonHeaders(Aws::Http::HttpRequest& httpRequest) const; - std::shared_ptr GetBodyStream(const Aws::AmazonWebServiceRequest& request) const; + void AppendHeaderValueToRequest(const std::shared_ptr &request, String header, String value) const; std::shared_ptr m_httpClient; std::shared_ptr m_errorMarshaller; @@ -355,9 +355,7 @@ namespace Aws bool m_enableClockSkewAdjustment; Aws::String m_serviceName = "AWSBaseClient"; Aws::Client::RequestCompressionConfig m_requestCompressionConfig; - void AppendHeaderValueToRequest( - const std::shared_ptr &request, String header, - String value) const; + Aws::Vector> m_interceptors; }; AWS_CORE_API Aws::String GetAuthorizationHeader(const Aws::Http::HttpRequest& httpRequest); diff --git a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientAsyncRequestContext.h b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientAsyncRequestContext.h index e47779ac0f2..ac7ef4656f6 100644 --- a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientAsyncRequestContext.h +++ b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientAsyncRequestContext.h @@ -5,6 +5,7 @@ #pragma once #include +#include #include #include @@ -70,6 +71,7 @@ namespace smithy ResponseHandlerFunc m_responseHandler; std::shared_ptr m_pExecutor; + std::shared_ptr m_interceptorContext; }; } // namespace client } // namespace smithy diff --git a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h index 85fc668d226..045c9ce7d43 100644 --- a/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h +++ b/src/aws-cpp-sdk-core/include/smithy/client/AwsSmithyClientBase.h @@ -7,6 +7,8 @@ #include #include #include +#include +#include #include #include @@ -85,7 +87,8 @@ namespace client m_serviceName(std::move(serviceName)), m_userAgent(), m_httpClient(std::move(httpClient)), - m_errorMarshaller(std::move(errorMarshaller)) + m_errorMarshaller(std::move(errorMarshaller)), + m_interceptors{Aws::MakeShared("AwsSmithyClientBase")} { if (!m_clientConfig->retryStrategy) { @@ -163,6 +166,7 @@ namespace client std::shared_ptr m_httpClient; std::shared_ptr m_errorMarshaller; + Aws::Vector> m_interceptors{}; }; } // namespace client } // namespace smithy diff --git a/src/aws-cpp-sdk-core/include/smithy/client/features/Checksums.h b/src/aws-cpp-sdk-core/include/smithy/client/features/ChecksumInterceptor.h similarity index 81% rename from src/aws-cpp-sdk-core/include/smithy/client/features/Checksums.h rename to src/aws-cpp-sdk-core/include/smithy/client/features/ChecksumInterceptor.h index 729ace43391..68aba6dee67 100644 --- a/src/aws-cpp-sdk-core/include/smithy/client/features/Checksums.h +++ b/src/aws-cpp-sdk-core/include/smithy/client/features/ChecksumInterceptor.h @@ -5,6 +5,8 @@ #pragma once +#include + #include #include #include @@ -26,7 +28,7 @@ namespace smithy static const char CHECKSUM_CONTENT_MD5_HEADER[] = "content-md5"; - class Checksums + class ChecksumInterceptor: public smithy::interceptor::Interceptor { public: using HeaderValueCollection = Aws::Http::HeaderValueCollection; @@ -38,6 +40,12 @@ namespace smithy using Sha1 = Aws::Utils::Crypto::Sha1; using PrecalculatedHash = Aws::Utils::Crypto::PrecalculatedHash; + ~ChecksumInterceptor() override = default; + ChecksumInterceptor() = default; + ChecksumInterceptor(const ChecksumInterceptor& other) = delete; + ChecksumInterceptor(ChecksumInterceptor&& other) noexcept = default; + ChecksumInterceptor& operator=(const ChecksumInterceptor& other) = delete; + ChecksumInterceptor& operator=(ChecksumInterceptor&& other) noexcept = default; static std::shared_ptr GetBodyStream(const Aws::AmazonWebServiceRequest& request) { @@ -49,9 +57,17 @@ namespace smithy return Aws::MakeShared(AWS_SMITHY_CLIENT_CHECKSUM, ""); } - static void AddChecksumToRequest(const std::shared_ptr& httpRequest, - const Aws::AmazonWebServiceRequest& request) + ModifyRequestOutcome ModifyBeforeSigning(interceptor::InterceptorContext& context) override { + const auto& httpRequest = context.GetTransmitRequest(); + const auto& request = context.GetModeledRequest(); + if (httpRequest == nullptr) + { + return Aws::Client::AWSError{Aws::Client::CoreErrors::VALIDATION, + "ValidationErrorException", + "Checksum request validation missing request", + false}; + } Aws::String checksumAlgorithmName = Aws::Utils::StringUtils::ToLower( request.GetChecksumAlgorithmName().c_str()); if (request.GetServiceSpecificParameters()) @@ -176,7 +192,7 @@ namespace smithy { std::shared_ptr crc32 = Aws::MakeShared< CRC32>(AWS_SMITHY_CLIENT_CHECKSUM); - httpRequest->AddResponseValidationHash("crc", crc32); + httpRequest->AddResponseValidationHash("crc32", crc32); } else if (checksumAlgorithmName == "sha1") { @@ -198,17 +214,21 @@ namespace smithy } } } + return httpRequest; } - using OptionalError = Aws::Crt::Optional>; - - static OptionalError ValidateResponseChecksum(AwsSmithyClientAsyncRequestContext const* const pRequestCtx, - Aws::Http::HttpResponse const* const httpResponse) + ModifyResponseOutcome ModifyBeforeDeserialization(interceptor::InterceptorContext& context) override { - assert(pRequestCtx); - assert(httpResponse); - assert(pRequestCtx->m_httpRequest); - for (const auto& hashIterator : pRequestCtx->m_httpRequest->GetResponseValidationHashes()) + const auto httpRequest = context.GetTransmitRequest(); + const auto httpResponse = context.GetTransmitResponse(); + if (httpRequest == nullptr || httpResponse == nullptr) + { + return Aws::Client::AWSError{Aws::Client::CoreErrors::VALIDATION, + "ValidationErrorException", + "Checksum response validation missing request or response", + false}; + } + for (const auto& hashIterator : httpRequest->GetResponseValidationHashes()) { Aws::String checksumHeaderKey = Aws::String("x-amz-checksum-") + hashIterator.first; // TODO: If checksum ends with -#, then skip @@ -218,24 +238,23 @@ namespace smithy if (HashingUtils::Base64Encode(hashIterator.second->GetHash().GetResult()) != checksumHeaderValue) { - auto error = OptionalError( - Aws::Client::AWSError( + auto error = Aws::Client::AWSError{ Aws::Client::CoreErrors::VALIDATION, "", "Response checksums mismatch", - false/*retryable*/)); - error->SetResponseHeaders(httpResponse->GetHeaders()); - error->SetResponseCode(httpResponse->GetResponseCode()); - error->SetRemoteHostIpAddress( - httpResponse->GetOriginatingRequest().GetResolvedRemoteHost()); - - AWS_LOGSTREAM_ERROR(AWS_SMITHY_CLIENT_CHECKSUM, *error); - return error; + false/*retryable*/}; + error.SetResponseHeaders(httpResponse->GetHeaders()); + error.SetResponseCode(httpResponse->GetResponseCode()); + error.SetRemoteHostIpAddress( + httpResponse->GetOriginatingRequest().GetResolvedRemoteHost()); + + AWS_LOGSTREAM_ERROR(AWS_SMITHY_CLIENT_CHECKSUM, error); + return {error}; } // Validate only a single checksum returned in an HTTP response break; } } - return {}; + return httpResponse; } }; } diff --git a/src/aws-cpp-sdk-core/include/smithy/interceptor/Interceptor.h b/src/aws-cpp-sdk-core/include/smithy/interceptor/Interceptor.h index ce5c4a4244e..b133857624b 100644 --- a/src/aws-cpp-sdk-core/include/smithy/interceptor/Interceptor.h +++ b/src/aws-cpp-sdk-core/include/smithy/interceptor/Interceptor.h @@ -15,10 +15,10 @@ namespace smithy virtual ~Interceptor() = default; using ModifyRequestOutcome = Aws::Utils::Outcome, Aws::Client::AWSError>; - virtual ModifyRequestOutcome ModifyRequest(InterceptorContext& context) = 0; + virtual ModifyRequestOutcome ModifyBeforeSigning(InterceptorContext& context) = 0; using ModifyResponseOutcome = Aws::Utils::Outcome, Aws::Client::AWSError>; - virtual ModifyResponseOutcome ModifyResponse(InterceptorContext& context) = 0; + virtual ModifyResponseOutcome ModifyBeforeDeserialization(InterceptorContext& context) = 0; }; } } diff --git a/src/aws-cpp-sdk-core/include/smithy/interceptor/InterceptorContext.h b/src/aws-cpp-sdk-core/include/smithy/interceptor/InterceptorContext.h index a5abb40061b..08cfc836f37 100644 --- a/src/aws-cpp-sdk-core/include/smithy/interceptor/InterceptorContext.h +++ b/src/aws-cpp-sdk-core/include/smithy/interceptor/InterceptorContext.h @@ -4,10 +4,10 @@ */ #pragma once #include -#include +#include +#include #include #include -#include #include namespace smithy @@ -17,78 +17,57 @@ namespace smithy class InterceptorContext { public: - InterceptorContext() = default; + explicit InterceptorContext(const Aws::AmazonWebServiceRequest& m_modeled_request) + : m_modeledRequest(m_modeled_request) + { + } + virtual ~InterceptorContext() = default; InterceptorContext(const InterceptorContext& other) = delete; - InterceptorContext(InterceptorContext&& other) noexcept = default; + InterceptorContext(InterceptorContext&& other) noexcept = delete; InterceptorContext& operator=(const InterceptorContext& other) = delete; - InterceptorContext& operator=(InterceptorContext&& other) noexcept = default; + InterceptorContext& operator=(InterceptorContext&& other) noexcept = delete; + + const Aws::AmazonWebServiceRequest& GetModeledRequest() const + { + return m_modeledRequest; + } - using GetRequestOutcome = Aws::Utils::Outcome, Aws::Client::AWSError>; - GetRequestOutcome GetRequest() const + std::shared_ptr GetTransmitRequest() const { - if (!m_request) - { - return Aws::Client::AWSError{ - Aws::Client::CoreErrors::RESOURCE_NOT_FOUND, - "ResourceNotFoundException", - "Request is NULL", - false - }; - } - return m_request; + return m_transmitRequest; } - void SetRequest(const std::shared_ptr& request) + void SetTransmitRequest(const std::shared_ptr& transmitRequest) { - this->m_request = request; + m_transmitRequest = transmitRequest; } - using GetResponseOutcome = Aws::Utils::Outcome, Aws::Client::AWSError>; - GetResponseOutcome GetResponse() const + std::shared_ptr GetTransmitResponse() const { - if (!m_response) - { - return Aws::Client::AWSError{ - Aws::Client::CoreErrors::RESOURCE_NOT_FOUND, - "ResourceNotFoundException", - "Response is NULL", - false - }; - } - return m_response; + return m_transmitResponse; } - void SetResponse(const std::shared_ptr& response) + void SetTransmitResponse(const std::shared_ptr& transmitResponse) { - this->m_response = response; + m_transmitResponse = transmitResponse; } - using GetAttributeOutcome = Aws::Utils::Outcome>; - GetAttributeOutcome GetAttribute(const Aws::String& attribute) + Aws::String GetAttribute(const Aws::String& key) const { - const auto attribute_iter = m_attributes.find(attribute); - if (attribute_iter == m_attributes.end()) - { - return Aws::Client::AWSError{ - Aws::Client::CoreErrors::RESOURCE_NOT_FOUND, - "ResourceNotFoundException", - "Attribute not found", - false - }; - } - return attribute_iter->second; + return m_attributes.at(key); } - void SetAttribute(const Aws::String& attribute, const Aws::String& value) + void SetAttribute(const Aws::String& key, const Aws::String& value) { - m_attributes.emplace(attribute, value); + m_attributes.insert({key, value}); } private: Aws::Map m_attributes{}; - std::shared_ptr m_request{nullptr}; - std::shared_ptr m_response{nullptr}; + const Aws::AmazonWebServiceRequest& m_modeledRequest; + std::shared_ptr m_transmitRequest{nullptr}; + std::shared_ptr m_transmitResponse{nullptr}; }; } } diff --git a/src/aws-cpp-sdk-core/source/client/AWSClient.cpp b/src/aws-cpp-sdk-core/source/client/AWSClient.cpp index f7b728b134c..7d33207e9de 100644 --- a/src/aws-cpp-sdk-core/source/client/AWSClient.cpp +++ b/src/aws-cpp-sdk-core/source/client/AWSClient.cpp @@ -45,6 +45,7 @@ #include #include +#include #include #include @@ -57,6 +58,7 @@ using namespace Aws::Utils; using namespace Aws::Utils::Json; using namespace Aws::Utils::Xml; using namespace smithy::components::tracing; +using namespace smithy::interceptor; static const int SUCCESS_RESPONSE_MIN = 200; static const int SUCCESS_RESPONSE_MAX = 299; @@ -136,7 +138,8 @@ AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration, m_hash(Aws::Utils::Crypto::CreateMD5Implementation()), m_requestTimeoutMs(configuration.requestTimeoutMs), m_enableClockSkewAdjustment(configuration.enableClockSkewAdjustment), - m_requestCompressionConfig(configuration.requestCompressionConfig) + m_requestCompressionConfig(configuration.requestCompressionConfig), + m_interceptors{Aws::MakeShared(AWS_CLIENT_LOG_TAG)} { } @@ -163,6 +166,7 @@ AWSClient::AWSClient(const Aws::Client::ClientConfiguration& configuration, m_enableClockSkewAdjustment(configuration.enableClockSkewAdjustment), m_requestCompressionConfig(configuration.requestCompressionConfig) { + m_interceptors.emplace_back(Aws::MakeUnique(AWS_CLIENT_LOG_TAG)); } void AWSClient::DisableRequestProcessing() @@ -543,6 +547,17 @@ HttpResponseOutcome AWSClient::AttemptOneRequest(const std::shared_ptrgetMeter(this->GetServiceClientName(), {}), {{TracingUtils::SMITHY_METHOD_DIMENSION, request.GetServiceRequestName()},{TracingUtils::SMITHY_SERVICE_DIMENSION, this->GetServiceClientName()}}); + InterceptorContext context{request}; + context.SetTransmitRequest(httpRequest); + for (const auto& interceptor : m_interceptors) + { + const auto modifiedRequest = interceptor->ModifyBeforeSigning(context); + if (!modifiedRequest.IsSuccess()) + { + return modifiedRequest.GetError(); + } + } + auto signer = GetSignerByName(signerName); auto signedRequest = TracingUtils::MakeCallWithTiming([&]() -> bool { return signer->SignRequest(*httpRequest, signerRegionOverride, signerServiceNameOverride, true); @@ -570,27 +585,13 @@ HttpResponseOutcome AWSClient::AttemptOneRequest(const std::shared_ptrgetMeter(this->GetServiceClientName(), {}), {{TracingUtils::SMITHY_METHOD_DIMENSION, request.GetServiceRequestName()},{TracingUtils::SMITHY_SERVICE_DIMENSION, this->GetServiceClientName()}}); - if (request.ShouldValidateResponseChecksum()) + context.SetTransmitResponse(httpResponse); + for (const auto& interceptor : m_interceptors) { - for (const auto& hashIterator : httpRequest->GetResponseValidationHashes()) + const auto modifiedRequest = interceptor->ModifyBeforeDeserialization(context); + if (!modifiedRequest.IsSuccess()) { - Aws::String checksumHeaderKey = Aws::String("x-amz-checksum-") + hashIterator.first; - // TODO: If checksum ends with -#, then skip - if (httpResponse->HasHeader(checksumHeaderKey.c_str())) - { - Aws::String checksumHeaderValue = httpResponse->GetHeader(checksumHeaderKey.c_str()); - if (HashingUtils::Base64Encode(hashIterator.second->GetHash().GetResult()) != checksumHeaderValue) - { - AWSError error(CoreErrors::VALIDATION, "", "Response checksums mismatch", false/*retryable*/); - error.SetResponseHeaders(httpResponse->GetHeaders()); - error.SetResponseCode(httpResponse->GetResponseCode()); - error.SetRemoteHostIpAddress(httpResponse->GetOriginatingRequest().GetResolvedRemoteHost()); - AWS_LOGSTREAM_ERROR(AWS_CLIENT_LOG_TAG, error); - return HttpResponseOutcome(error); - } - // Validate only a single checksum returned in an HTTP response - break; - } + return modifiedRequest.GetError(); } } @@ -796,129 +797,6 @@ void AWSClient::AppendHeaderValueToRequest(const std::shared_ptr& httpRequest, - const Aws::AmazonWebServiceRequest& request) const -{ - Aws::String checksumAlgorithmName = Aws::Utils::StringUtils::ToLower(request.GetChecksumAlgorithmName().c_str()); - if (request.GetServiceSpecificParameters()) { - auto requestChecksumOverride = request.GetServiceSpecificParameters()->parameterMap.find("overrideChecksum"); - if (requestChecksumOverride != request.GetServiceSpecificParameters()->parameterMap.end()) { - checksumAlgorithmName = requestChecksumOverride->second; - } - } - - bool shouldSkipChecksum = request.GetServiceSpecificParameters() && - request.GetServiceSpecificParameters()->parameterMap.find("overrideChecksumDisable") != - request.GetServiceSpecificParameters()->parameterMap.end(); - - //Check if user has provided the checksum algorithm - if (!checksumAlgorithmName.empty() && !shouldSkipChecksum) - { - // Check if user has provided a checksum value for the specified algorithm - const Aws::String checksumType = "x-amz-checksum-" + checksumAlgorithmName; - const HeaderValueCollection &headers = request.GetHeaders(); - const auto checksumHeader = headers.find(checksumType); - bool checksumValueAndAlgorithmProvided = checksumHeader != headers.end(); - - // For non-streaming payload, the resolved checksum location is always header. - // For streaming payload, the resolved checksum location depends on whether it is an unsigned payload, we let AwsAuthSigner decide it. - if (request.IsStreaming() && checksumValueAndAlgorithmProvided) - { - const auto hash = Aws::MakeShared(AWS_CLIENT_LOG_TAG, checksumHeader->second); - httpRequest->SetRequestHash(checksumAlgorithmName,hash); - } - else if (checksumValueAndAlgorithmProvided){ - httpRequest->SetHeaderValue(checksumType, checksumHeader->second); - } - else if (checksumAlgorithmName == "crc32") - { - if (request.IsStreaming()) - { - httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(AWS_CLIENT_LOG_TAG)); - } - else - { - httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateCRC32(*(GetBodyStream(request))))); - } - } - else if (checksumAlgorithmName == "crc32c") - { - if (request.IsStreaming()) - { - httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(AWS_CLIENT_LOG_TAG)); - } - else - { - httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateCRC32C(*(GetBodyStream(request))))); - } - } - else if (checksumAlgorithmName == "sha256") - { - if (request.IsStreaming()) - { - httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(AWS_CLIENT_LOG_TAG)); - } - else - { - httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateSHA256(*(GetBodyStream(request))))); - } - } - else if (checksumAlgorithmName == "sha1") - { - if (request.IsStreaming()) - { - httpRequest->SetRequestHash(checksumAlgorithmName, Aws::MakeShared(AWS_CLIENT_LOG_TAG)); - } - else - { - httpRequest->SetHeaderValue(checksumType, HashingUtils::Base64Encode(HashingUtils::CalculateSHA1(*(GetBodyStream(request))))); - } - } - else if (checksumAlgorithmName == "md5" && headers.find(CONTENT_MD5_HEADER) == headers.end()) - { - httpRequest->SetHeaderValue(Http::CONTENT_MD5_HEADER, HashingUtils::Base64Encode(HashingUtils::CalculateMD5(*(GetBodyStream(request))))); - } - else if (headers.find(CONTENT_MD5_HEADER) == headers.end()) - { - AWS_LOGSTREAM_WARN(AWS_CLIENT_LOG_TAG, "Checksum algorithm: " << checksumAlgorithmName << " is not supported by SDK."); - } - } - - // Response checksums - if (request.ShouldValidateResponseChecksum()) - { - for (const Aws::String& responseChecksumAlgorithmName : request.GetResponseChecksumAlgorithmNames()) - { - checksumAlgorithmName = Aws::Utils::StringUtils::ToLower(responseChecksumAlgorithmName.c_str()); - - if (checksumAlgorithmName == "crc32c") - { - std::shared_ptr crc32c = Aws::MakeShared(AWS_CLIENT_LOG_TAG); - httpRequest->AddResponseValidationHash("crc32c", crc32c); - } - else if (checksumAlgorithmName == "crc32") - { - std::shared_ptr crc32 = Aws::MakeShared(AWS_CLIENT_LOG_TAG); - httpRequest->AddResponseValidationHash("crc", crc32); - } - else if (checksumAlgorithmName == "sha1") - { - std::shared_ptr sha1 = Aws::MakeShared(AWS_CLIENT_LOG_TAG); - httpRequest->AddResponseValidationHash("sha1", sha1); - } - else if (checksumAlgorithmName == "sha256") - { - std::shared_ptr sha256 = Aws::MakeShared(AWS_CLIENT_LOG_TAG); - httpRequest->AddResponseValidationHash("sha256", sha256); - } - else - { - AWS_LOGSTREAM_WARN(AWS_CLIENT_LOG_TAG, "Checksum algorithm: " << checksumAlgorithmName << " is not supported in validating response body yet."); - } - } - } -} - void AWSClient::AddContentBodyToRequest(const std::shared_ptr& httpRequest, const std::shared_ptr& body, bool needsContentMd5, bool isChunked) const { @@ -1034,14 +912,12 @@ void AWSClient::BuildHttpRequest(const Aws::AmazonWebServiceRequest& request, co } } - AddChecksumToRequest(httpRequest, request); // Pass along handlers for processing data sent/received in bytes httpRequest->SetHeadersReceivedEventHandler(request.GetHeadersReceivedEventHandler()); httpRequest->SetDataReceivedEventHandler(request.GetDataReceivedEventHandler()); httpRequest->SetDataSentEventHandler(request.GetDataSentEventHandler()); httpRequest->SetContinueRequestHandle(request.GetContinueRequestHandler()); httpRequest->SetServiceSpecificParameters(request.GetServiceSpecificParameters()); - request.AddQueryStringParameters(httpRequest->GetUri()); } @@ -1133,14 +1009,6 @@ Aws::String AWSClient::GeneratePresignedUrl(const Aws::AmazonWebServiceRequest& return AWSUrlPresigner(*this).GeneratePresignedUrl(request, uri, method, extraParams, expirationInSeconds, serviceSpecificParameter); } -std::shared_ptr AWSClient::GetBodyStream(const Aws::AmazonWebServiceRequest& request) const { - if (request.GetBody() != nullptr) { - return request.GetBody(); - } - // Return an empty string stream for no body - return Aws::MakeShared(AWS_CLIENT_LOG_TAG, ""); -} - std::shared_ptr AWSClient::MakeHttpRequest(std::shared_ptr& request) const { return m_httpClient->MakeRequest(request, m_readRateLimiter.get(), m_writeRateLimiter.get()); diff --git a/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp b/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp index c361d9b8a2b..05eb3b6ac24 100644 --- a/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp +++ b/src/aws-cpp-sdk-core/source/smithy/client/AwsSmithyClientBase.cpp @@ -5,7 +5,6 @@ #include #include -#include #include #include @@ -20,6 +19,7 @@ #include "smithy/tracing/TracingUtils.h" using namespace smithy::client; +using namespace smithy::interceptor; using namespace smithy::components::tracing; static const char AWS_SMITHY_CLIENT_LOG[] = "AwsSmithyClient"; @@ -81,7 +81,6 @@ AwsSmithyClientBase::BuildHttpRequest(const std::shared_ptrSetHeadersReceivedEventHandler(pRequest->GetHeadersReceivedEventHandler()); httpRequest->SetDataReceivedEventHandler(pRequest->GetDataReceivedEventHandler()); @@ -167,6 +166,7 @@ void AwsSmithyClientBase::MakeRequestAsync(Aws::AmazonWebServiceRequest const* c } pRequestCtx->m_requestInfo.attempt = 1; pRequestCtx->m_requestInfo.maxAttempts = 0; + pRequestCtx->m_interceptorContext = Aws::MakeShared(AWS_SMITHY_CLIENT_LOG, *request); AttemptOneRequestAsync(std::move(pRequestCtx)); } @@ -202,6 +202,20 @@ void AwsSmithyClientBase::AttemptOneRequestAsync(std::shared_ptrm_interceptorContext->SetTransmitRequest(pRequestCtx->m_httpRequest); + for (const auto& interceptor : m_interceptors) + { + auto modifiedRequest = interceptor->ModifyBeforeSigning(*pRequestCtx->m_interceptorContext); + if (!modifiedRequest.IsSuccess()) + { + pExecutor->Submit([modifiedRequest, responseHandler]() mutable + { + responseHandler(modifiedRequest.GetError()); + }); + return; + } + } + Aws::Monitoring::CoreMetricsCollection coreMetrics; pRequestCtx->m_monitoringContexts = Aws::Monitoring::OnRequestStarted(this->GetServiceClientName(), pRequestCtx->m_requestName, @@ -290,14 +304,15 @@ void AwsSmithyClientBase::HandleAsyncReply(std::shared_ptrm_pRequest && pRequestCtx->m_pRequest->ShouldValidateResponseChecksum()) + pRequestCtx->m_interceptorContext->SetTransmitResponse(httpResponse); + for (const auto& interceptor : m_interceptors) { - auto checksumError = Checksums::ValidateResponseChecksum(pRequestCtx.get(), httpResponse.get()); - if (checksumError) + const auto modifiedResponse = interceptor->ModifyBeforeDeserialization(*pRequestCtx->m_interceptorContext); + if (!modifiedResponse.IsSuccess()) { - return pRequestCtx->m_responseHandler(HttpResponseOutcome(std::move(*checksumError))); + return pRequestCtx->m_responseHandler(HttpResponseOutcome(modifiedResponse.GetError())); } - } + }; bool hasEmbeddedError = pRequestCtx->m_pRequest && pRequestCtx->m_pRequest->HasEmbeddedError(httpResponse->GetResponseBody(), httpResponse->GetHeaders()); diff --git a/tests/aws-cpp-sdk-core-tests/CMakeLists.txt b/tests/aws-cpp-sdk-core-tests/CMakeLists.txt index 3928ebdd977..90406667b9a 100644 --- a/tests/aws-cpp-sdk-core-tests/CMakeLists.txt +++ b/tests/aws-cpp-sdk-core-tests/CMakeLists.txt @@ -26,8 +26,8 @@ file(GLOB MONITORING_SRC "${CMAKE_CURRENT_SOURCE_DIR}/monitoring/*.cpp") file(GLOB SMITHY_TRACING_SRC "${CMAKE_CURRENT_SOURCE_DIR}/smithy/tracing/*.cpp") file(GLOB SMITHY_CLIENT_SRC "${CMAKE_CURRENT_SOURCE_DIR}/smithy/client/*.cpp") file(GLOB SMITHY_CLIENT_SERIALIZER_SRC "${CMAKE_CURRENT_SOURCE_DIR}/smithy/client/serializer/*.cpp") +file(GLOB SMITHY_CLIENT_FEATURE_SRC "${CMAKE_CURRENT_SOURCE_DIR}/smithy/client/feature/*.cpp") file(GLOB ENDPOINT_SRC "${CMAKE_CURRENT_SOURCE_DIR}/endpoint/*.cpp") -file(GLOB SMITHY_INTERCEPTOR_SRC "${CMAKE_CURRENT_SOURCE_DIR}/smithy/interceptor/*.cpp") file(GLOB AWS_CPP_SDK_CORE_TESTS_SRC @@ -54,6 +54,7 @@ file(GLOB AWS_CPP_SDK_CORE_TESTS_SRC ${SMITHY_TRACING_SRC} ${SMITHY_CLIENT_SRC} ${SMITHY_CLIENT_SERIALIZER_SRC} + ${SMITHY_CLIENT_FEATURE_SRC} ${ENDPOINT_SRC} ${SMITHY_INTERCEPTOR_SRC} ) @@ -80,8 +81,8 @@ if(PLATFORM_WINDOWS) source_group("Source Files\\utils\\threading" FILES ${UTILS_THREADING_SRC}) source_group("Source Files\\smithy\\tracing" FILES ${SMITHY_TRACING_SRC}) source_group("Source Files\\smithy\\client" FILES ${SMITHY_CLIENT_SRC}) - source_group("Source Files\\smithy\\client" FILES ${SMITHY_CLIENT_SERIALIZER_SRC}) - source_group("Source Files\\smithy\\client" FILES ${SMITHY_INTERCEPTOR_SRC}) + source_group("Source Files\\smithy\\client\\serializer" FILES ${SMITHY_CLIENT_SERIALIZER_SRC}) + source_group("Source Files\\smithy\\client\\feature" FILES ${SMITHY_CLIENT_FEATURE_SRC}) endif() endif() diff --git a/tests/aws-cpp-sdk-core-tests/smithy/client/feature/ChecksumInterceptorTest.cpp b/tests/aws-cpp-sdk-core-tests/smithy/client/feature/ChecksumInterceptorTest.cpp new file mode 100644 index 00000000000..f4fe4afb5d1 --- /dev/null +++ b/tests/aws-cpp-sdk-core-tests/smithy/client/feature/ChecksumInterceptorTest.cpp @@ -0,0 +1,181 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ +#include +#include +#include + +using namespace smithy::interceptor; +using namespace smithy::client; +using namespace Aws; +using namespace Aws::Http; +using namespace Aws::Http::Standard; +using namespace Aws::Testing; + +static const char* ALLOC_TAG = "ChecksumInterceptorTest"; + +class ChecksumInterceptorTest : public AwsCppSdkGTestSuite +{ +protected: + ChecksumInterceptor m_interceptor; +}; + +class MockChecksumRequest: public AmazonWebServiceRequest +{ +public: + explicit MockChecksumRequest(const Aws::String& m_response_body, + const Aws::String& checksumAlgorithmName = "", + const bool shouldAwsChunk = false, + const bool shouldValidateResponseChecksum = false, + const Aws::Vector& responseValidationChecksums = {}) + : m_responseBody(m_response_body), + m_checksumAlgorithmName{checksumAlgorithmName}, + m_shouldAwsChunk{shouldAwsChunk}, + m_shouldValidateResponseChecksum{shouldValidateResponseChecksum}, + m_responseChecksumsToValidate{responseValidationChecksums} + { + } + + ~MockChecksumRequest() override = default; + + std::shared_ptr GetBody() const override + { + return Aws::MakeShared(ALLOC_TAG, m_responseBody); + } + + Aws::Http::HeaderValueCollection GetHeaders() const override + { + return {}; + } + + const char* GetServiceRequestName() const override + { + return "LeblancCafeService"; + } + + inline Aws::String GetChecksumAlgorithmName() const override + { + return m_checksumAlgorithmName; + } + + bool IsStreaming() const override + { + return m_shouldAwsChunk; + } + + inline bool ShouldValidateResponseChecksum() const override + { + return m_shouldValidateResponseChecksum; + } + + inline Aws::Vector GetResponseChecksumAlgorithmNames() const override + { + return m_responseChecksumsToValidate; + } + +private: + Aws::String m_responseBody; + Aws::String m_checksumAlgorithmName{}; + bool m_shouldAwsChunk{false}; + bool m_shouldValidateResponseChecksum{false}; + Aws::Vector m_responseChecksumsToValidate{}; +}; + +TEST_F(ChecksumInterceptorTest, MissingRequestInContextShouldReturnError) +{ + MockChecksumRequest request{"Take your time"}; + InterceptorContext context{request}; + const auto outcome = m_interceptor.ModifyBeforeSigning(context); + EXPECT_FALSE(outcome.IsSuccess()); + EXPECT_EQ(Client::CoreErrors::VALIDATION, outcome.GetError().GetErrorType()); + EXPECT_EQ("Checksum request validation missing request", outcome.GetError().GetMessage()); + EXPECT_EQ("ValidationErrorException", outcome.GetError().GetExceptionName()); +} + +TEST_F(ChecksumInterceptorTest, MissingResponseInContextShouldReturnError) +{ + MockChecksumRequest request{"Take your time"}; + InterceptorContext context{request}; + const auto outcome = m_interceptor.ModifyBeforeDeserialization(context); + EXPECT_FALSE(outcome.IsSuccess()); + EXPECT_EQ(Client::CoreErrors::VALIDATION, outcome.GetError().GetErrorType()); + EXPECT_EQ("Checksum response validation missing request or response", outcome.GetError().GetMessage()); + EXPECT_EQ("ValidationErrorException", outcome.GetError().GetExceptionName()); +} + +TEST_F(ChecksumInterceptorTest, ChecksumInterceptorShouldAddTrailingChecksumToRequest) +{ + MockChecksumRequest modeledRequest{"Take your time", "crc32", true}; + InterceptorContext context{modeledRequest}; + URI uri{"https://www.persona5.com/joker"}; + std::shared_ptr request(CreateHttpRequest(uri, HttpMethod::HTTP_GET, Utils::Stream::DefaultResponseStreamFactoryMethod)); + context.SetTransmitRequest(request); + const auto outcome = m_interceptor.ModifyBeforeSigning(context); + EXPECT_TRUE(outcome.IsSuccess()); + EXPECT_EQ("crc32", outcome.GetResult()->GetRequestHash().first); +} + +TEST_F(ChecksumInterceptorTest, ChecksumInterceptorShouldAddHeaderChecksumToRequest) +{ + MockChecksumRequest modeledRequest{"Take your time", "crc32", false}; + InterceptorContext context{modeledRequest}; + URI uri{"https://www.persona5.com/joker"}; + std::shared_ptr request(CreateHttpRequest(uri, HttpMethod::HTTP_GET, Utils::Stream::DefaultResponseStreamFactoryMethod)); + context.SetTransmitRequest(request); + const auto outcome = m_interceptor.ModifyBeforeSigning(context); + EXPECT_TRUE(outcome.IsSuccess()); + EXPECT_EQ("", outcome.GetResult()->GetRequestHash().first); + EXPECT_EQ(nullptr, outcome.GetResult()->GetRequestHash().second); + EXPECT_EQ(2ul, outcome.GetResult()->GetHeaders().size()); + EXPECT_EQ("www.persona5.com", outcome.GetResult()->GetHeaderValue("host")); + EXPECT_EQ("KQcztA==", outcome.GetResult()->GetHeaderValue("x-amz-checksum-crc32")); +} + +TEST_F(ChecksumInterceptorTest, ChecksumInterceptorShouldValidateCorrectResponseChecksum) +{ + Aws::Vector responseValidationChecksumsToValidate{"crc32"}; + MockChecksumRequest modeledRequest{"Take your time", "crc32", true, true, responseValidationChecksumsToValidate}; + InterceptorContext context{modeledRequest}; + URI uri{"https://www.persona5.com/joker"}; + std::shared_ptr request(CreateHttpRequest(uri, HttpMethod::HTTP_GET, Utils::Stream::DefaultResponseStreamFactoryMethod)); + context.SetTransmitRequest(request); + const auto requestOutcome = m_interceptor.ModifyBeforeSigning(context); + auto responseHashes = request->GetResponseValidationHashes(); + EXPECT_EQ(1ul, responseHashes.size()); + EXPECT_EQ("crc32", responseHashes[0].first); + EXPECT_NE(nullptr, responseHashes[0].second); + auto bodyStr = Crt::ByteBufFromCString("Take your time"); + responseHashes[0].second->Update(bodyStr.buffer, bodyStr.len); + EXPECT_TRUE(requestOutcome.IsSuccess()); + std::shared_ptr response = Aws::MakeShared(ALLOC_TAG, request); + response->AddHeader("x-amz-checksum-crc32", "KQcztA=="); + context.SetTransmitResponse(response); + const auto responseOutcome = m_interceptor.ModifyBeforeDeserialization(context); + EXPECT_TRUE(responseOutcome.IsSuccess()); +} + +TEST_F(ChecksumInterceptorTest, ChecksumInterceptorShouldValidateBadResponseChecksum) +{ + Aws::Vector responseValidationChecksumsToValidate{"crc32"}; + MockChecksumRequest modeledRequest{"Take your time", "crc32", true, true, responseValidationChecksumsToValidate}; + InterceptorContext context{modeledRequest}; + URI uri{"https://www.persona5.com/joker"}; + std::shared_ptr request(CreateHttpRequest(uri, HttpMethod::HTTP_GET, Utils::Stream::DefaultResponseStreamFactoryMethod)); + context.SetTransmitRequest(request); + const auto requestOutcome = m_interceptor.ModifyBeforeSigning(context); + auto responseHashes = request->GetResponseValidationHashes(); + EXPECT_EQ(1ul, responseHashes.size()); + EXPECT_EQ("crc32", responseHashes[0].first); + EXPECT_NE(nullptr, responseHashes[0].second); + auto bodyStr = Crt::ByteBufFromCString("Take your ~corrupted~ time"); + responseHashes[0].second->Update(bodyStr.buffer, bodyStr.len); + EXPECT_TRUE(requestOutcome.IsSuccess()); + std::shared_ptr response = Aws::MakeShared(ALLOC_TAG, request); + response->AddHeader("x-amz-checksum-crc32", "KQcztA=="); + context.SetTransmitResponse(response); + const auto responseOutcome = m_interceptor.ModifyBeforeDeserialization(context); + EXPECT_FALSE(responseOutcome.IsSuccess()); + EXPECT_EQ(Client::CoreErrors::VALIDATION, responseOutcome.GetError().GetErrorType()); + EXPECT_EQ("Response checksums mismatch", responseOutcome.GetError().GetMessage()); +} diff --git a/tests/aws-cpp-sdk-core-tests/smithy/interceptor/InterceptorTest.cpp b/tests/aws-cpp-sdk-core-tests/smithy/interceptor/InterceptorTest.cpp index ef47ff5d909..2378d5a3792 100644 --- a/tests/aws-cpp-sdk-core-tests/smithy/interceptor/InterceptorTest.cpp +++ b/tests/aws-cpp-sdk-core-tests/smithy/interceptor/InterceptorTest.cpp @@ -16,6 +16,8 @@ using namespace Aws::Utils; using namespace Aws::Utils::Stream; using namespace Aws::Testing; +const char* ALLOCATION_TAG = "SmithyInterceptorTest"; + class SmithyInterceptorTest : public AwsCppSdkGTestSuite { }; @@ -26,16 +28,16 @@ class MockSuccessInterceptor : public Interceptor MockSuccessInterceptor() = default; ~MockSuccessInterceptor() override = default; - ModifyRequestOutcome ModifyRequest(InterceptorContext& context) override + ModifyRequestOutcome ModifyBeforeSigning(InterceptorContext& context) override { context.SetAttribute("MockInterceptorRequest", "Called"); - return context.GetRequest(); + return context.GetTransmitRequest(); } - ModifyResponseOutcome ModifyResponse(InterceptorContext& context) override + ModifyResponseOutcome ModifyBeforeDeserialization(InterceptorContext& context) override { context.SetAttribute("MockInterceptorResponse", "Called"); - return context.GetResponse(); + return context.GetTransmitResponse(); } }; @@ -45,7 +47,7 @@ class MockRequestFailureInterceptor : public Interceptor MockRequestFailureInterceptor() = default; ~MockRequestFailureInterceptor() override = default; - ModifyRequestOutcome ModifyRequest(InterceptorContext& context) override + ModifyRequestOutcome ModifyBeforeSigning(InterceptorContext& context) override { context.SetAttribute("MockInterceptorRequest", "Called"); return Aws::Client::AWSError{ @@ -56,10 +58,10 @@ class MockRequestFailureInterceptor : public Interceptor };; } - ModifyResponseOutcome ModifyResponse(InterceptorContext& context) override + ModifyResponseOutcome ModifyBeforeDeserialization(InterceptorContext& context) override { context.SetAttribute("MockInterceptorResponse", "Called"); - return context.GetResponse(); + return context.GetTransmitResponse(); } }; @@ -69,13 +71,13 @@ class MockResponseFailureInterceptor : public Interceptor MockResponseFailureInterceptor() = default; ~MockResponseFailureInterceptor() override = default; - ModifyRequestOutcome ModifyRequest(InterceptorContext& context) override + ModifyRequestOutcome ModifyBeforeSigning(InterceptorContext& context) override { context.SetAttribute("MockInterceptorRequest", "Called"); - return context.GetRequest(); + return context.GetTransmitRequest(); } - ModifyResponseOutcome ModifyResponse(InterceptorContext& context) override + ModifyResponseOutcome ModifyBeforeDeserialization(InterceptorContext& context) override { context.SetAttribute("MockInterceptorResponse", "Called"); return Aws::Client::AWSError{ @@ -87,6 +89,35 @@ class MockResponseFailureInterceptor : public Interceptor } }; +class MockInterceptorRequest: public AmazonWebServiceRequest +{ +public: + explicit MockInterceptorRequest(const Aws::String& m_response_body) + : m_responseBody(m_response_body) + { + } + + ~MockInterceptorRequest() override = default; + + std::shared_ptr GetBody() const override + { + return Aws::MakeShared(ALLOCATION_TAG, m_responseBody); + } + + Aws::Http::HeaderValueCollection GetHeaders() const override + { + return {}; + } + + const char* GetServiceRequestName() const override + { + return "LeblancCafeService"; + } + +private: + Aws::String m_responseBody; +}; + class MockClient { public: @@ -104,28 +135,29 @@ class MockClient } using RequestOutcome = Outcome, AWSError>; - RequestOutcome MakeRequest(const std::shared_ptr& request, InterceptorContext& context) const + RequestOutcome MakeRequest(InterceptorContext& context, + const std::shared_ptr& request) const { - context.SetRequest(request); + context.SetTransmitRequest(request); for (const auto& interceptor: m_interceptors) { - const auto modifiedRequest = interceptor->ModifyRequest(context); + const auto modifiedRequest = interceptor->ModifyBeforeSigning(context); if (!modifiedRequest.IsSuccess()) { return modifiedRequest.GetError(); } } - auto response = Aws::MakeShared("SmithyInterceptorTest", request); - context.SetResponse(response); + auto response = Aws::MakeShared(ALLOCATION_TAG, request); + context.SetTransmitResponse(response); for (const auto& interceptor: m_interceptors) { - const auto modifiedResponse = interceptor->ModifyResponse(context); + const auto modifiedResponse = interceptor->ModifyBeforeDeserialization(context); if (!modifiedResponse.IsSuccess()) { return modifiedResponse.GetError(); } } - return context.GetResponse(); + return context.GetTransmitResponse(); } private: @@ -141,37 +173,39 @@ TEST_F(SmithyInterceptorTest, MockInterceptorShouldReturnSuccess) { const auto uri = "https://www.villagepsychic.net/"; auto request = CreateHttpRequest(URI{uri}, HttpMethod::HTTP_GET, DefaultResponseStreamFactoryMethod); - auto interceptor = Aws::MakeUnique("SmithyInterceptorTest"); + auto interceptor = Aws::MakeUnique(ALLOCATION_TAG); const auto client = MockClient::MakeClient(std::move(interceptor)); - InterceptorContext context{}; - const auto response = client.MakeRequest(request, context); + MockInterceptorRequest modeledRequest{"Take your time"}; + InterceptorContext context{modeledRequest}; + const auto response = client.MakeRequest(context, request); EXPECT_TRUE(response.IsSuccess()); - EXPECT_TRUE(context.GetAttribute("MockInterceptorRequest").IsSuccess()); - EXPECT_TRUE(context.GetAttribute("MockInterceptorResponse").IsSuccess()); + EXPECT_EQ("Called", context.GetAttribute("MockInterceptorRequest")); + EXPECT_EQ("Called", context.GetAttribute("MockInterceptorResponse")); } TEST_F(SmithyInterceptorTest, MockInterceptorShouldReturnFailureRequset) { const auto uri = "https://www.villagepsychic.net/"; auto request = CreateHttpRequest(URI{uri}, HttpMethod::HTTP_GET, DefaultResponseStreamFactoryMethod); - auto interceptor = Aws::MakeUnique("SmithyInterceptorTest"); + auto interceptor = Aws::MakeUnique(ALLOCATION_TAG); const auto client = MockClient::MakeClient(std::move(interceptor)); - InterceptorContext context{}; - const auto response = client.MakeRequest(request, context); + MockInterceptorRequest modeledRequest{"Take your time"}; + InterceptorContext context{modeledRequest}; + const auto response = client.MakeRequest(context, request); EXPECT_FALSE(response.IsSuccess()); - EXPECT_TRUE(context.GetAttribute("MockInterceptorRequest").IsSuccess()); - EXPECT_FALSE(context.GetAttribute("MockInterceptorResponse").IsSuccess()); + EXPECT_EQ("Called", context.GetAttribute("MockInterceptorRequest")); } TEST_F(SmithyInterceptorTest, MockInterceptorShouldReturnFailureReseponse) { const auto uri = "https://www.villagepsychic.net/"; auto request = CreateHttpRequest(URI{uri}, HttpMethod::HTTP_GET, DefaultResponseStreamFactoryMethod); - auto interceptor = Aws::MakeUnique("SmithyInterceptorTest"); + auto interceptor = Aws::MakeUnique(ALLOCATION_TAG); const auto client = MockClient::MakeClient(std::move(interceptor)); - InterceptorContext context{}; - const auto response = client.MakeRequest(request, context); + MockInterceptorRequest modeledRequest{"Take your time"}; + InterceptorContext context{modeledRequest}; + const auto response = client.MakeRequest(context, request); EXPECT_FALSE(response.IsSuccess()); - EXPECT_TRUE(context.GetAttribute("MockInterceptorRequest").IsSuccess()); - EXPECT_TRUE(context.GetAttribute("MockInterceptorResponse").IsSuccess()); + EXPECT_EQ("Called", context.GetAttribute("MockInterceptorRequest")); + EXPECT_EQ("Called", context.GetAttribute("MockInterceptorResponse")); } \ No newline at end of file