Skip to content

Commit 4393c28

Browse files
committed
pipe down request to sigv4
override request down provider chain used custom env declaration for credential track test
1 parent da98492 commit 4393c28

File tree

6 files changed

+102
-26
lines changed

6 files changed

+102
-26
lines changed

src/aws-cpp-sdk-core/include/aws/core/auth/AWSCredentialsProviderChain.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ namespace Aws
2929
*/
3030
virtual AWSCredentials GetAWSCredentials();
3131

32+
/**
33+
* When a credentials provider in the chain returns empty credentials,
34+
* We go on to the next provider until we have either exhausted the installed providers in the chain or something returns non-empty credentials.
35+
* This overload passes the request to providers for user agent feature tracking.
36+
*/
37+
virtual AWSCredentials GetAWSCredentials(Aws::AmazonWebServiceRequest& request);
38+
3239
/**
3340
* Gets all providers stored in this chain.
3441
*/

src/aws-cpp-sdk-core/include/aws/core/auth/signer/AWSAuthV4Signer.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace smithy
2727

2828
namespace Aws
2929
{
30+
class AmazonWebServiceRequest;
31+
3032
namespace Http
3133
{
3234
class HttpRequest;
@@ -142,6 +144,13 @@ namespace Aws
142144
*/
143145
bool SignRequest(Aws::Http::HttpRequest& request, const char* region, const char* serviceName, bool signBody) const override;
144146

147+
/**
148+
* Uses AWS Auth V4 signing method with SHA256 HMAC algorithm. If signBody is false
149+
* and https is being used then the body of the payload will not be signed.
150+
* This overload passes the AWS request to the credentials provider for user agent feature tracking.
151+
*/
152+
bool SignRequest(Aws::Http::HttpRequest& request, Aws::AmazonWebServiceRequest& awsRequest, const char* region, const char* serviceName, bool signBody) const;
153+
145154
/**
146155
* Takes a request and signs the URI based on the HttpMethod, URI and other info from the request.
147156
* the region the signer was initialized with will be used for the signature.
@@ -183,6 +192,8 @@ namespace Aws
183192

184193
virtual Aws::Auth::AWSCredentials GetCredentials(const std::shared_ptr<Aws::Http::ServiceSpecificParameters> &serviceSpecificParameters) const;
185194

195+
virtual Aws::Auth::AWSCredentials GetCredentials(Aws::AmazonWebServiceRequest& awsRequest, const std::shared_ptr<Aws::Http::ServiceSpecificParameters> &serviceSpecificParameters) const;
196+
186197
Aws::String GetServiceName() const { return m_serviceName; }
187198
Aws::String GetRegion() const { return m_region; }
188199
Aws::String GenerateSignature(const Aws::Auth::AWSCredentials& credentials,

src/aws-cpp-sdk-core/source/auth/AWSCredentialsProviderChain.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,29 @@ AWSCredentials AWSCredentialsProviderChain::GetAWSCredentials()
3333
{
3434
AWSCredentials credentials = credentialsProvider->GetAWSCredentials();
3535
if (!credentials.GetAWSAccessKeyId().empty() && !credentials.GetAWSSecretKey().empty())
36+
{
37+
m_cachedProvider = credentialsProvider;
38+
return credentials;
39+
}
40+
}
41+
return AWSCredentials();
42+
}
43+
44+
AWSCredentials AWSCredentialsProviderChain::GetAWSCredentials(Aws::AmazonWebServiceRequest& request)
45+
{
46+
ReaderLockGuard lock(m_cachedProviderLock);
47+
if (m_cachedProvider) {
48+
AWSCredentials credentials = m_cachedProvider->GetAWSCredentials(request);
49+
if (!credentials.GetAWSAccessKeyId().empty() && !credentials.GetAWSSecretKey().empty())
50+
{
51+
return credentials;
52+
}
53+
}
54+
lock.UpgradeToWriterLock();
55+
for (auto&& credentialsProvider : m_providerChain)
56+
{
57+
AWSCredentials credentials = credentialsProvider->GetAWSCredentials(request);
58+
if (!credentials.GetAWSAccessKeyId().empty() && !credentials.GetAWSSecretKey().empty())
3659
{
3760
// TODO: issue of only chain, not overidden
3861
// which credentials were used -- add it somethow

src/aws-cpp-sdk-core/source/auth/signer/AWSAuthV4Signer.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,12 @@ bool AWSAuthV4Signer::SignRequest(Aws::Http::HttpRequest& request, const char* r
339339
return SignRequestWithCreds(request, credentials, region, serviceName, signBody);
340340
}
341341

342+
bool AWSAuthV4Signer::SignRequest(Aws::Http::HttpRequest& request, Aws::AmazonWebServiceRequest& awsRequest, const char* region, const char* serviceName, bool signBody) const
343+
{
344+
AWSCredentials credentials = GetCredentials(awsRequest, request.GetServiceSpecificParameters());
345+
return SignRequestWithCreds(request, credentials, region, serviceName, signBody);
346+
}
347+
342348
bool AWSAuthV4Signer::PresignRequest(Aws::Http::HttpRequest& request, long long expirationTimeInSeconds) const
343349
{
344350
return PresignRequest(request, m_region.c_str(), expirationTimeInSeconds);
@@ -595,3 +601,8 @@ Aws::Auth::AWSCredentials AWSAuthV4Signer::GetCredentials(const std::shared_ptr<
595601
AWS_UNREFERENCED_PARAM(serviceSpecificParameters);
596602
return m_credentialsProvider->GetAWSCredentials();
597603
}
604+
605+
Aws::Auth::AWSCredentials AWSAuthV4Signer::GetCredentials(Aws::AmazonWebServiceRequest& awsRequest, const std::shared_ptr<Aws::Http::ServiceSpecificParameters> &serviceSpecificParameters) const {
606+
AWS_UNREFERENCED_PARAM(serviceSpecificParameters);
607+
return m_credentialsProvider->GetAWSCredentials(awsRequest);
608+
}

src/aws-cpp-sdk-core/source/client/AWSClient.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <aws/core/AmazonWebServiceRequest.h>
88
#include <aws/core/auth/AWSAuthSigner.h>
99
#include <aws/core/auth/AWSAuthSignerProvider.h>
10+
#include <aws/core/auth/signer/AWSAuthV4Signer.h>
1011
#include <aws/core/client/AWSUrlPresigner.h>
1112
#include <aws/core/client/AWSError.h>
1213
#include <aws/core/client/AWSErrorMarshaller.h>
@@ -579,6 +580,11 @@ HttpResponseOutcome AWSClient::AttemptOneRequest(const std::shared_ptr<Aws::Http
579580

580581
auto signer = GetSignerByName(signerName);
581582
auto signedRequest = TracingUtils::MakeCallWithTiming<bool>([&]() -> bool {
583+
// Use request-aware signing for credential tracking
584+
auto* v4Signer = dynamic_cast<AWSAuthV4Signer*>(signer);
585+
if (v4Signer) {
586+
return v4Signer->SignRequest(*httpRequest, const_cast<Aws::AmazonWebServiceRequest&>(request), signerRegionOverride, signerServiceNameOverride, true);
587+
}
582588
return signer->SignRequest(*httpRequest, signerRegionOverride, signerServiceNameOverride, true);
583589
},
584590
TracingUtils::SMITHY_CLIENT_SIGNING_METRIC,
@@ -590,12 +596,6 @@ HttpResponseOutcome AWSClient::AttemptOneRequest(const std::shared_ptr<Aws::Http
590596
return HttpResponseOutcome(AWSError<CoreErrors>(CoreErrors::CLIENT_SIGNING_FAILURE, "", "SDK failed to sign the request", false/*retryable*/));
591597
}
592598

593-
// Track credential provider usage for User-Agent features
594-
auto credentialsProvider = GetCredentialsProvider();
595-
if (credentialsProvider) {
596-
credentialsProvider->GetAWSCredentials(const_cast<Aws::AmazonWebServiceRequest&>(request));
597-
}
598-
599599
if (request.GetRequestSignedHandler())
600600
{
601601
request.GetRequestSignedHandler()(*httpRequest);

tests/aws-cpp-sdk-core-tests/aws/auth/CredentialTrackingTest.cpp

Lines changed: 44 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,45 @@
77
#include <aws/testing/AwsTestHelpers.h>
88
#include <aws/testing/mocks/aws/client/MockAWSClient.h>
99
#include <aws/testing/mocks/http/MockHttpClient.h>
10+
#include <aws/testing/platform/PlatformTesting.h>
1011
#include <aws/core/auth/AWSCredentialsProvider.h>
1112
#include <aws/core/client/AWSClient.h>
1213
#include <aws/core/utils/StringUtils.h>
13-
#include <aws/core/platform/Environment.h>
1414

1515
using namespace Aws::Client;
1616
using namespace Aws::Auth;
1717
using namespace Aws::Http;
1818

1919
static const char ALLOCATION_TAG[] = "CredentialTrackingTest";
2020

21+
// Custom client that uses environment credential provider for testing
22+
class CredentialTestingClient : public Aws::Client::AWSClient
23+
{
24+
public:
25+
explicit CredentialTestingClient(const Aws::Client::ClientConfiguration& configuration)
26+
: AWSClient(configuration,
27+
Aws::MakeShared<Aws::Client::AWSAuthV4Signer>(ALLOCATION_TAG,
28+
Aws::MakeShared<EnvironmentAWSCredentialsProvider>(ALLOCATION_TAG),
29+
"service", configuration.region),
30+
Aws::MakeShared<MockAWSErrorMarshaller>(ALLOCATION_TAG))
31+
{
32+
}
33+
34+
Aws::Client::HttpResponseOutcome MakeRequest(const Aws::AmazonWebServiceRequest& request)
35+
{
36+
auto uri = Aws::Http::URI("https://test.com");
37+
return AWSClient::AttemptExhaustively(uri, request, Aws::Http::HttpMethod::HTTP_POST, Aws::Auth::SIGV4_SIGNER);
38+
}
39+
40+
const char* GetServiceClientName() const override { return "CredentialTestingClient"; }
2141

42+
protected:
43+
Aws::Client::AWSError<Aws::Client::CoreErrors> BuildAWSError(const std::shared_ptr<Aws::Http::HttpResponse>& response) const override
44+
{
45+
AWS_UNREFERENCED_PARAM(response);
46+
return Aws::Client::AWSError<Aws::Client::CoreErrors>(Aws::Client::CoreErrors::UNKNOWN, false);
47+
}
48+
};
2249

2350
class CredentialTrackingTest : public Aws::Testing::AwsCppSdkGTestSuite
2451
{
@@ -46,12 +73,14 @@ class CredentialTrackingTest : public Aws::Testing::AwsCppSdkGTestSuite
4673

4774
TEST_F(CredentialTrackingTest, TestEnvironmentCredentialsTracking)
4875
{
49-
Aws::Environment::SetEnv("AWS_ACCESS_KEY_ID", "test-access-key", 1);
50-
Aws::Environment::SetEnv("AWS_SECRET_ACCESS_KEY", "test-secret-key", 1);
76+
Aws::Environment::EnvironmentRAII testEnvironment{{
77+
{"AWS_ACCESS_KEY_ID", "test-access-key"},
78+
{"AWS_SECRET_ACCESS_KEY", "test-secret-key"},
79+
}};
5180

5281
// Setup mock response
5382
std::shared_ptr<HttpRequest> requestTmp =
54-
CreateHttpRequest(Aws::Http::URI("dummy"), Aws::Http::HttpMethod::HTTP_POST,
83+
CreateHttpRequest(Aws::Http::URI("dummy"), Aws::Http::HttpMethod::HTTP_POST,
5584
Aws::Utils::Stream::DefaultResponseStreamFactoryMethod);
5685
auto successResponse = Aws::MakeShared<Standard::StandardHttpResponse>(ALLOCATION_TAG, requestTmp);
5786
successResponse->SetResponseCode(HttpResponseCode::OK);
@@ -64,22 +93,20 @@ TEST_F(CredentialTrackingTest, TestEnvironmentCredentialsTracking)
6493
Aws::Client::ClientConfiguration clientConfig(cfgInit);
6594
clientConfig.region = Aws::Region::US_EAST_1;
6695

67-
// Create client with environment credentials signer
68-
Aws::Client::AWSClient client(clientConfig,
69-
Aws::MakeShared<Aws::Client::AWSAuthV4Signer>(ALLOCATION_TAG,
70-
Aws::MakeShared<EnvironmentAWSCredentialsProvider>(ALLOCATION_TAG),
71-
"service", clientConfig.region),
72-
Aws::MakeShared<MockAWSErrorMarshaller>(ALLOCATION_TAG));
73-
96+
// Create credential testing client that uses default provider chain
97+
CredentialTestingClient client(clientConfig);
98+
99+
// Create mock request
74100
AmazonWebServiceRequestMock mockRequest;
75-
auto outcome = client.AttemptExhaustively(Aws::Http::URI("https://test.com"), mockRequest,
76-
Aws::Http::HttpMethod::HTTP_POST, Aws::Auth::SIGV4_SIGNER);
77-
AWS_ASSERT_SUCCESS(outcome);
101+
102+
// Make request
103+
auto outcome = client.MakeRequest(mockRequest);
104+
ASSERT_TRUE(outcome.IsSuccess());
78105

79106
// Verify User-Agent contains environment credentials tracking
80107
auto lastRequest = mockHttpClient->GetMostRecentHttpRequest();
81-
EXPECT_TRUE(lastRequest.HasUserAgent());
82-
const auto& userAgent = lastRequest.GetUserAgent();
108+
EXPECT_TRUE(lastRequest.HasHeader(Aws::Http::USER_AGENT_HEADER));
109+
const auto& userAgent = lastRequest.GetHeaderValue(Aws::Http::USER_AGENT_HEADER);
83110
EXPECT_FALSE(userAgent.empty());
84111

85112
const auto userAgentParsed = Aws::Utils::StringUtils::Split(userAgent, ' ');
@@ -89,7 +116,4 @@ TEST_F(CredentialTrackingTest, TestEnvironmentCredentialsTracking)
89116
[](const Aws::String& value) { return value.find("m/") != Aws::String::npos && value.find("g") != Aws::String::npos; });
90117

91118
EXPECT_TRUE(businessMetrics != userAgentParsed.end());
92-
93-
Aws::Environment::UnSetEnv("AWS_ACCESS_KEY_ID");
94-
Aws::Environment::UnSetEnv("AWS_SECRET_ACCESS_KEY");
95-
}
119+
}

0 commit comments

Comments
 (0)