Skip to content

DGS-21595 Allow alternate KMS key IDs on a KEK #2018

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import io
import logging
import time
from typing import Optional, Tuple, Any
from typing import Optional, Tuple, Any, List

from tink import aead, daead, KmsClient, kms_client_from_uri, \
register_kms_client, TinkError
Expand Down Expand Up @@ -45,6 +45,7 @@
ENCRYPT_KMS_TYPE = "encrypt.kms.type"
ENCRYPT_DEK_ALGORITHM = "encrypt.dek.algorithm"
ENCRYPT_DEK_EXPIRY_DAYS = "encrypt.dek.expiry.days"
ENCRYPT_ALTERNATE_KMS_KEY_IDS = "encrypt.alternate.kms.key.ids"

MILLIS_IN_DAY = 24 * 60 * 60 * 1000

Expand Down Expand Up @@ -279,7 +280,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek:
raise RuleError(f"no dek found for {dek_id.kek_name} during consume")
encrypted_dek = None
if not kek.shared:
primitive = self._get_aead(self._executor.config, self._kek)
primitive = AeadWrapper(self._executor.config, self._kek)
raw_dek = self._cryptor.generate_key()
encrypted_dek = primitive.encrypt(raw_dek, self._cryptor.EMPTY_AAD)
new_version = dek.version + 1 if is_expired else 1
Expand All @@ -293,7 +294,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek:
key_bytes = dek.get_key_material_bytes()
if key_bytes is None:
if primitive is None:
primitive = self._get_aead(self._executor.config, self._kek)
primitive = AeadWrapper(self._executor.config, self._kek)
encrypted_dek = dek.get_encrypted_key_material_bytes()
raw_dek = primitive.decrypt(encrypted_dek, self._cryptor.EMPTY_AAD)
dek.set_key_material(raw_dek)
Expand Down Expand Up @@ -410,8 +411,51 @@ def _to_object(self, field_type: FieldType, value: bytes) -> Any:
return value
return None

def _get_aead(self, config: dict, kek: Kek) -> aead.Aead:
kek_url = kek.kms_type + "://" + kek.kms_key_id

class AeadWrapper(aead.Aead):
def __init__(self, config: dict, kek: Kek):
self._config = config
self._kek = kek
self._kms_key_ids = self._get_kms_key_ids()

def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
for index, kms_key_id in enumerate(self._kms_key_ids):
try:
aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id)
return aead.encrypt(plaintext, associated_data)
except Exception as e:
log.warning("failed to encrypt with kek %s and kms key id %s",
self._kek.name, kms_key_id)
if index == len(self._kms_key_ids) - 1:
raise RuleError(f"failed to encrypt with all KEKs for {self._kek.name}") from e
raise RuleError("No KEK found for encryption")

def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
for index, kms_key_id in enumerate(self._kms_key_ids):
try:
aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id)
return aead.decrypt(ciphertext, associated_data)
except Exception as e:
log.warning("failed to decrypt with kek %s and kms key id %s",
self._kek.name, kms_key_id)
if index == len(self._kms_key_ids) - 1:
raise RuleError(f"failed to decrypt with all KEKs for {self._kek.name}") from e
raise RuleError("No KEK found for decryption")

def _get_kms_key_ids(self) -> List[str]:
kms_key_ids = [self._kek.kms_key_id]
alternate_kms_key_ids = None
if self._kek.kms_props is not None:
alternate_kms_key_ids = self._kek.kms_props.properties.get(ENCRYPT_ALTERNATE_KMS_KEY_IDS)
if alternate_kms_key_ids is None:
alternate_kms_key_ids = self._config.get(ENCRYPT_ALTERNATE_KMS_KEY_IDS)
if alternate_kms_key_ids is not None:
# Split the comma-separated list of alternate KMS key IDs and append to kms_key_ids
kms_key_ids.extend([id.strip() for id in alternate_kms_key_ids.split(',') if id.strip()])
return kms_key_ids

def _get_aead(self, config: dict, kms_type: str, kms_key_id: str) -> aead.Aead:
kek_url = kms_type + "://" + kms_key_id
kms_client = self._get_kms_client(config, kek_url)
return kms_client.get_aead(kek_url)

Expand Down
62 changes: 62 additions & 0 deletions tests/schema_registry/_async/test_avro_serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
async def test_avro_basic_serialization():
conf = {'url': _BASE_URL}
client = AsyncSchemaRegistryClient.new_client(conf)
ser_conf = {'auto.register.schemas': True}

Check failure on line 105 in tests/schema_registry/_async/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_async/test_avro_serdes.py#L105

Define a constant instead of duplicating this literal 'auto.register.schemas' 32 times.
obj = {
'intField': 123,
'doubleField': 45.67,
Expand Down Expand Up @@ -241,7 +241,7 @@
async def test_avro_serialize_references():
conf = {'url': _BASE_URL}
client = AsyncSchemaRegistryClient.new_client(conf)
ser_conf = {'auto.register.schemas': False, 'use.latest.version': True}

Check failure on line 244 in tests/schema_registry/_async/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_async/test_avro_serdes.py#L244

Define a constant instead of duplicating this literal 'use.latest.version' 29 times.

referenced = {
'intField': 123,
Expand Down Expand Up @@ -792,7 +792,7 @@
{'name': 'mapField', 'type':
{'type': 'map', 'values': 'string'}
},
{'name': 'unionField', 'type': ['null', 'string'], 'confluent:tags': ['PII']}

Check failure on line 795 in tests/schema_registry/_async/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_async/test_avro_serdes.py#L795

Define a constant instead of duplicating this literal 'confluent:tags' 20 times.
]
}

Expand Down Expand Up @@ -1100,13 +1100,13 @@
"ENCRYPT",
["PII"],
RuleParams({
"encrypt.kek.name": "kek1",

Check failure on line 1103 in tests/schema_registry/_async/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_async/test_avro_serdes.py#L1103

Define a constant instead of duplicating this literal "encrypt.kek.name" 11 times.
"encrypt.kms.type": "local-kms",

Check failure on line 1104 in tests/schema_registry/_async/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_async/test_avro_serdes.py#L1104

Define a constant instead of duplicating this literal "encrypt.kms.type" 11 times.
"encrypt.kms.key.id": "mykey"

Check failure on line 1105 in tests/schema_registry/_async/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_async/test_avro_serdes.py#L1105

Define a constant instead of duplicating this literal "encrypt.kms.key.id" 11 times.
}),
None,
None,
"ERROR,NONE",

Check failure on line 1109 in tests/schema_registry/_async/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_async/test_avro_serdes.py#L1109

Define a constant instead of duplicating this literal "ERROR,NONE" 8 times.
False
)
await client.register_schema(_SUBJECT, Schema(
Expand Down Expand Up @@ -1202,6 +1202,68 @@
assert obj == obj2


async def test_avro_encryption_alternate_keks():
executor = EncryptionExecutor.register_with_clock(FakeClock())

conf = {'url': _BASE_URL}
client = AsyncSchemaRegistryClient.new_client(conf)
ser_conf = {'auto.register.schemas': False, 'use.latest.version': True}
rule_conf = {'secret': 'mysecret', 'encrypt.alternate.kms.key.ids': 'mykey2,mykey3'}
schema = {
'type': 'record',
'name': 'test',
'fields': [
{'name': 'intField', 'type': 'int'},
{'name': 'doubleField', 'type': 'double'},
{'name': 'stringField', 'type': 'string', 'confluent:tags': ['PII']},
{'name': 'booleanField', 'type': 'boolean'},
{'name': 'bytesField', 'type': 'bytes', 'confluent:tags': ['PII']},
]
}

rule = Rule(
"test-encrypt",
"",
RuleKind.TRANSFORM,
RuleMode.WRITEREAD,
"ENCRYPT_PAYLOAD",
None,
RuleParams({
"encrypt.kek.name": "kek1",
"encrypt.kms.type": "local-kms",
"encrypt.kms.key.id": "mykey"
}),
None,
None,
"ERROR,NONE",
False
)
await client.register_schema(_SUBJECT, Schema(
json.dumps(schema),
"AVRO",
[],
None,
RuleSet(None, None, [rule])
))

obj = {
'intField': 123,
'doubleField': 45.67,
'stringField': 'hi',
'booleanField': True,
'bytesField': b'foobar',
}
ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf)
dek_client = executor.client
ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
obj_bytes = await ser(obj, ser_ctx)

deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf)
executor.client = dek_client
obj2 = await deser(obj_bytes, ser_ctx)
assert obj == obj2


async def test_avro_encryption_deterministic():
executor = FieldEncryptionExecutor.register_with_clock(FakeClock())

Expand Down
62 changes: 62 additions & 0 deletions tests/schema_registry/_sync/test_avro_serdes.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
def test_avro_basic_serialization():
conf = {'url': _BASE_URL}
client = SchemaRegistryClient.new_client(conf)
ser_conf = {'auto.register.schemas': True}

Check failure on line 105 in tests/schema_registry/_sync/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_sync/test_avro_serdes.py#L105

Define a constant instead of duplicating this literal 'auto.register.schemas' 32 times.
obj = {
'intField': 123,
'doubleField': 45.67,
Expand Down Expand Up @@ -241,7 +241,7 @@
def test_avro_serialize_references():
conf = {'url': _BASE_URL}
client = SchemaRegistryClient.new_client(conf)
ser_conf = {'auto.register.schemas': False, 'use.latest.version': True}

Check failure on line 244 in tests/schema_registry/_sync/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_sync/test_avro_serdes.py#L244

Define a constant instead of duplicating this literal 'use.latest.version' 29 times.

referenced = {
'intField': 123,
Expand Down Expand Up @@ -792,7 +792,7 @@
{'name': 'mapField', 'type':
{'type': 'map', 'values': 'string'}
},
{'name': 'unionField', 'type': ['null', 'string'], 'confluent:tags': ['PII']}

Check failure on line 795 in tests/schema_registry/_sync/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_sync/test_avro_serdes.py#L795

Define a constant instead of duplicating this literal 'confluent:tags' 20 times.
]
}

Expand Down Expand Up @@ -1100,13 +1100,13 @@
"ENCRYPT",
["PII"],
RuleParams({
"encrypt.kek.name": "kek1",

Check failure on line 1103 in tests/schema_registry/_sync/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_sync/test_avro_serdes.py#L1103

Define a constant instead of duplicating this literal "encrypt.kek.name" 11 times.
"encrypt.kms.type": "local-kms",

Check failure on line 1104 in tests/schema_registry/_sync/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_sync/test_avro_serdes.py#L1104

Define a constant instead of duplicating this literal "encrypt.kms.type" 11 times.
"encrypt.kms.key.id": "mykey"

Check failure on line 1105 in tests/schema_registry/_sync/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_sync/test_avro_serdes.py#L1105

Define a constant instead of duplicating this literal "encrypt.kms.key.id" 11 times.
}),
None,
None,
"ERROR,NONE",

Check failure on line 1109 in tests/schema_registry/_sync/test_avro_serdes.py

View check run for this annotation

SonarQube-Confluent / confluent-kafka-python Sonarqube Results

tests/schema_registry/_sync/test_avro_serdes.py#L1109

Define a constant instead of duplicating this literal "ERROR,NONE" 8 times.
False
)
client.register_schema(_SUBJECT, Schema(
Expand Down Expand Up @@ -1202,6 +1202,68 @@
assert obj == obj2


def test_avro_encryption_alternate_keks():
executor = EncryptionExecutor.register_with_clock(FakeClock())

conf = {'url': _BASE_URL}
client = SchemaRegistryClient.new_client(conf)
ser_conf = {'auto.register.schemas': False, 'use.latest.version': True}
rule_conf = {'secret': 'mysecret', 'encrypt.alternate.kms.key.ids': 'mykey2,mykey3'}
schema = {
'type': 'record',
'name': 'test',
'fields': [
{'name': 'intField', 'type': 'int'},
{'name': 'doubleField', 'type': 'double'},
{'name': 'stringField', 'type': 'string', 'confluent:tags': ['PII']},
{'name': 'booleanField', 'type': 'boolean'},
{'name': 'bytesField', 'type': 'bytes', 'confluent:tags': ['PII']},
]
}

rule = Rule(
"test-encrypt",
"",
RuleKind.TRANSFORM,
RuleMode.WRITEREAD,
"ENCRYPT_PAYLOAD",
None,
RuleParams({
"encrypt.kek.name": "kek1",
"encrypt.kms.type": "local-kms",
"encrypt.kms.key.id": "mykey"
}),
None,
None,
"ERROR,NONE",
False
)
client.register_schema(_SUBJECT, Schema(
json.dumps(schema),
"AVRO",
[],
None,
RuleSet(None, None, [rule])
))

obj = {
'intField': 123,
'doubleField': 45.67,
'stringField': 'hi',
'booleanField': True,
'bytesField': b'foobar',
}
ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf)
dek_client = executor.client
ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
obj_bytes = ser(obj, ser_ctx)

deser = AvroDeserializer(client, rule_conf=rule_conf)
executor.client = dek_client
obj2 = deser(obj_bytes, ser_ctx)
assert obj == obj2


def test_avro_encryption_deterministic():
executor = FieldEncryptionExecutor.register_with_clock(FakeClock())

Expand Down