Skip to content

Commit 24eeeb0

Browse files
Updating EncryptionKeyFactory to add overirde AWS request configuration in KMS calls made generate encryption key (#3103)
1 parent 5d2d6f5 commit 24eeeb0

File tree

5 files changed

+87
-10
lines changed

5 files changed

+87
-10
lines changed

athena-dynamodb/src/main/java/com/amazonaws/athena/connectors/dynamodb/DynamoDBMetadataHandler.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,8 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest
559559
logger.info("QPT Split Requested");
560560
return setupQueryPassthroughSplit(request);
561561
}
562+
FederatedIdentity federatedIdentity = request.getIdentity();
563+
AwsRequestOverrideConfiguration overrideConfig = getRequestOverrideConfig(federatedIdentity.getConfigOptions());
562564

563565
int partitionContd = decodeContinuationToken(request);
564566
Set<Split> splits = new HashSet<>();
@@ -584,7 +586,7 @@ public GetSplitsResponse doGetSplits(BlockAllocator allocator, GetSplitsRequest
584586
Object hashKeyValue = DDBTypeUtils.convertArrowTypeIfNecessary(hashKeyName, hashKeyValueReader.readObject());
585587
splitMetadata.put(hashKeyName, DDBTypeUtils.attributeToJson(DDBTypeUtils.toAttributeValue(hashKeyValue), hashKeyName));
586588

587-
splits.add(new Split(spillLocation, makeEncryptionKey(), splitMetadata));
589+
splits.add(new Split(spillLocation, makeEncryptionKey(overrideConfig), splitMetadata));
588590

589591
if (splits.size() == MAX_SPLITS_PER_REQUEST && curPartition != partitions.getRowCount() - 1) {
590592
// We've reached max page size and this is not the last partition
@@ -609,7 +611,7 @@ else if (SCAN_PARTITION_TYPE.equals(partitionType)) {
609611
splitMetadata.put(SEGMENT_ID_PROPERTY, String.valueOf(curPartition));
610612
splitMetadata.put(SEGMENT_COUNT_METADATA, String.valueOf(segmentCount));
611613

612-
splits.add(new Split(spillLocation, makeEncryptionKey(), splitMetadata));
614+
splits.add(new Split(spillLocation, makeEncryptionKey(overrideConfig), splitMetadata));
613615

614616
if (splits.size() == MAX_SPLITS_PER_REQUEST && curPartition != segmentCount - 1) {
615617
// We've reached max page size and this is not the last partition
@@ -740,13 +742,15 @@ else if (useQueryPlan && filterPredicates.containsKey(rangeKeyName)) {
740742
*/
741743
private GetSplitsResponse setupQueryPassthroughSplit(GetSplitsRequest request)
742744
{
745+
FederatedIdentity federatedIdentity = request.getIdentity();
746+
AwsRequestOverrideConfiguration overrideConfig = getRequestOverrideConfig(federatedIdentity.getConfigOptions());
743747
//Every split must have a unique location if we wish to spill to avoid failures
744748
SpillLocation spillLocation = makeSpillLocation(request);
745749

746750
//Since this is QPT query we return a fixed split.
747751
Map<String, String> qptArguments = request.getConstraints().getQueryPassthroughArguments();
748752
return new GetSplitsResponse(request.getCatalogName(),
749-
Split.newBuilder(spillLocation, makeEncryptionKey())
753+
Split.newBuilder(spillLocation, makeEncryptionKey(overrideConfig))
750754
.applyProperties(qptArguments)
751755
.build());
752756
}

athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/handlers/MetadataHandler.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,6 +233,11 @@ protected EncryptionKey makeEncryptionKey()
233233
return (encryptionKeyFactory != null) ? encryptionKeyFactory.create() : null;
234234
}
235235

236+
protected EncryptionKey makeEncryptionKey(AwsRequestOverrideConfiguration awsRequestOverrideConfiguration)
237+
{
238+
return (encryptionKeyFactory != null) ? encryptionKeyFactory.create(awsRequestOverrideConfiguration) : null;
239+
}
240+
236241
/**
237242
* Used to make a spill location for a split. Each split should have a unique spill location, so be sure
238243
* to call this method once per split!

athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/EncryptionKeyFactory.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
* #L%
2121
*/
2222

23+
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
24+
2325
/**
2426
* Defines a factory that can be used to create AES-GCM compatible encryption keys.
2527
*/
@@ -29,4 +31,9 @@ public interface EncryptionKeyFactory
2931
* @return A key that satisfies the specification defined in BlockCrypto
3032
*/
3133
EncryptionKey create();
34+
35+
default EncryptionKey create(AwsRequestOverrideConfiguration awsRequestOverrideConfiguration)
36+
{
37+
return create();
38+
}
3239
}

athena-federation-sdk/src/main/java/com/amazonaws/athena/connector/lambda/security/KmsKeyFactory.java

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@
2121
*/
2222

2323
import com.amazonaws.athena.connector.lambda.exceptions.AthenaConnectorException;
24+
import org.slf4j.Logger;
25+
import org.slf4j.LoggerFactory;
26+
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
2427
import software.amazon.awssdk.services.glue.model.ErrorDetails;
2528
import software.amazon.awssdk.services.glue.model.FederationSourceErrorCode;
2629
import software.amazon.awssdk.services.kms.KmsClient;
@@ -39,6 +42,7 @@
3942
public class KmsKeyFactory
4043
implements EncryptionKeyFactory
4144
{
45+
private static final Logger logger = LoggerFactory.getLogger(KmsKeyFactory.class);
4246
private final KmsClient kmsClient;
4347
private final String masterKeyId;
4448

@@ -53,21 +57,30 @@ public KmsKeyFactory(KmsClient kmsClient, String masterKeyId)
5357
*/
5458
public EncryptionKey create()
5559
{
60+
return create(null);
61+
}
62+
63+
@Override
64+
public EncryptionKey create(AwsRequestOverrideConfiguration awsRequestOverrideConfiguration)
65+
{
66+
GenerateDataKeyRequest.Builder dataKeyBuilder = GenerateDataKeyRequest.builder()
67+
.keyId(masterKeyId)
68+
.keySpec(DataKeySpec.AES_128);
69+
if (awsRequestOverrideConfiguration != null) {
70+
logger.info("Using AWS KMS Request Override Configuration:");
71+
dataKeyBuilder.overrideConfiguration(awsRequestOverrideConfiguration);
72+
}
73+
5674
GenerateDataKeyResponse dataKeyResponse;
5775
try {
58-
dataKeyResponse = kmsClient.generateDataKey(
59-
GenerateDataKeyRequest.builder()
60-
.keyId(masterKeyId)
61-
.keySpec(DataKeySpec.AES_128)
62-
.build());
76+
dataKeyResponse = kmsClient.generateDataKey(dataKeyBuilder.build());
6377
}
6478
catch (NotFoundException e) {
6579
throw new AthenaConnectorException(e.getMessage(), ErrorDetails.builder().errorCode(FederationSourceErrorCode.ENTITY_NOT_FOUND_EXCEPTION.toString()).build());
6680
}
6781

6882
GenerateRandomRequest randomRequest = GenerateRandomRequest.builder()
69-
.numberOfBytes(AesGcmBlockCrypto.NONCE_BYTES)
70-
.build();
83+
.numberOfBytes(AesGcmBlockCrypto.NONCE_BYTES).build();
7184
GenerateRandomResponse randomResponse = kmsClient.generateRandom(randomRequest);
7285

7386
return new EncryptionKey(dataKeyResponse.plaintext().asByteArray(), randomResponse.plaintext().asByteArray());

athena-federation-sdk/src/test/java/com/amazonaws/athena/connector/lambda/security/KmsKeyFactoryTest.java

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
import org.mockito.Mockito;
3333
import org.mockito.MockitoAnnotations;
3434

35+
import software.amazon.awssdk.awscore.AwsRequestOverrideConfiguration;
3536
import software.amazon.awssdk.core.SdkBytes;
3637
import software.amazon.awssdk.services.kms.KmsClient;
3738
import software.amazon.awssdk.services.kms.model.GenerateDataKeyRequest;
@@ -95,4 +96,51 @@ public void testNotFoundException() {
9596
kmsKeyFactory.create();
9697
});
9798
}
99+
100+
@Test
101+
public void testCreateWithOverrideConfiguration() {
102+
byte[] testPlaintextKey = new byte[] { 1, 2, 3, 4, 5 };
103+
byte[] testNonce = new byte[] { 9, 8, 7 };
104+
AwsRequestOverrideConfiguration overrideConfig = AwsRequestOverrideConfiguration.builder().build();
105+
106+
GenerateDataKeyResponse dataKeyResponse = GenerateDataKeyResponse.builder()
107+
.plaintext(SdkBytes.fromByteArray(testPlaintextKey))
108+
.build();
109+
110+
GenerateRandomResponse randomResponse = GenerateRandomResponse.builder()
111+
.plaintext(SdkBytes.fromByteArray(testNonce))
112+
.build();
113+
114+
when(mockKmsClient.generateDataKey((GenerateDataKeyRequest) any())).thenReturn(dataKeyResponse);
115+
when(mockKmsClient.generateRandom((GenerateRandomRequest) any())).thenReturn(randomResponse);
116+
117+
EncryptionKey result = kmsKeyFactory.create(overrideConfig);
118+
119+
assertArrayEquals(testPlaintextKey, result.getKey());
120+
assertArrayEquals(testNonce, result.getNonce());
121+
verify(mockKmsClient).generateDataKey((GenerateDataKeyRequest) any());
122+
verify(mockKmsClient).generateRandom((GenerateRandomRequest) any());
123+
}
124+
125+
@Test
126+
public void testCreateWithNullOverrideConfiguration() {
127+
byte[] testPlaintextKey = new byte[] { 1, 2, 3, 4, 5 };
128+
byte[] testNonce = new byte[] { 9, 8, 7 };
129+
130+
GenerateDataKeyResponse dataKeyResponse = GenerateDataKeyResponse.builder()
131+
.plaintext(SdkBytes.fromByteArray(testPlaintextKey))
132+
.build();
133+
134+
GenerateRandomResponse randomResponse = GenerateRandomResponse.builder()
135+
.plaintext(SdkBytes.fromByteArray(testNonce))
136+
.build();
137+
138+
when(mockKmsClient.generateDataKey((GenerateDataKeyRequest) any())).thenReturn(dataKeyResponse);
139+
when(mockKmsClient.generateRandom((GenerateRandomRequest) any())).thenReturn(randomResponse);
140+
141+
EncryptionKey result = kmsKeyFactory.create(null);
142+
143+
assertArrayEquals(testPlaintextKey, result.getKey());
144+
assertArrayEquals(testNonce, result.getNonce());
145+
}
98146
}

0 commit comments

Comments
 (0)