From 431f3d5233810e5e51e93c014608c9a5c0d74bfe Mon Sep 17 00:00:00 2001 From: Robert Yokota Date: Mon, 4 Aug 2025 16:57:04 -0700 Subject: [PATCH] First cut --- .../rules/encryption/encrypt_executor.py | 54 ++++++++++++++-- .../_async/test_avro_serdes.py | 62 +++++++++++++++++++ .../schema_registry/_sync/test_avro_serdes.py | 62 +++++++++++++++++++ 3 files changed, 173 insertions(+), 5 deletions(-) diff --git a/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py b/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py index 8a05e3077..45a4ab99a 100644 --- a/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py +++ b/src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py @@ -16,7 +16,7 @@ import io import logging import time -from typing import Optional, Tuple, Any +from typing import Optional, Tuple, Any, List from tink import aead, daead, KmsClient, kms_client_from_uri, \ register_kms_client, TinkError @@ -45,6 +45,7 @@ ENCRYPT_KMS_TYPE = "encrypt.kms.type" ENCRYPT_DEK_ALGORITHM = "encrypt.dek.algorithm" ENCRYPT_DEK_EXPIRY_DAYS = "encrypt.dek.expiry.days" +ENCRYPT_ALTERNATE_KMS_KEY_IDS = "encrypt.alternate.kms.key.ids" MILLIS_IN_DAY = 24 * 60 * 60 * 1000 @@ -279,7 +280,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: raise RuleError(f"no dek found for {dek_id.kek_name} during consume") encrypted_dek = None if not kek.shared: - primitive = self._get_aead(self._executor.config, self._kek) + primitive = AeadWrapper(self._executor.config, self._kek) raw_dek = self._cryptor.generate_key() encrypted_dek = primitive.encrypt(raw_dek, self._cryptor.EMPTY_AAD) new_version = dek.version + 1 if is_expired else 1 @@ -293,7 +294,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek: key_bytes = dek.get_key_material_bytes() if key_bytes is None: if primitive is None: - primitive = self._get_aead(self._executor.config, self._kek) + primitive = AeadWrapper(self._executor.config, self._kek) encrypted_dek = dek.get_encrypted_key_material_bytes() raw_dek = primitive.decrypt(encrypted_dek, self._cryptor.EMPTY_AAD) dek.set_key_material(raw_dek) @@ -410,8 +411,51 @@ def _to_object(self, field_type: FieldType, value: bytes) -> Any: return value return None - def _get_aead(self, config: dict, kek: Kek) -> aead.Aead: - kek_url = kek.kms_type + "://" + kek.kms_key_id + +class AeadWrapper(aead.Aead): + def __init__(self, config: dict, kek: Kek): + self._config = config + self._kek = kek + self._kms_key_ids = self._get_kms_key_ids() + + def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes: + for index, kms_key_id in enumerate(self._kms_key_ids): + try: + aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) + return aead.encrypt(plaintext, associated_data) + except Exception as e: + log.warning("failed to encrypt with kek %s and kms key id %s", + self._kek.name, kms_key_id) + if index == len(self._kms_key_ids) - 1: + raise RuleError(f"failed to encrypt with all KEKs for {self._kek.name}") from e + raise RuleError("No KEK found for encryption") + + def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes: + for index, kms_key_id in enumerate(self._kms_key_ids): + try: + aead = self._get_aead(self._config, self._kek.kms_type, kms_key_id) + return aead.decrypt(ciphertext, associated_data) + except Exception as e: + log.warning("failed to decrypt with kek %s and kms key id %s", + self._kek.name, kms_key_id) + if index == len(self._kms_key_ids) - 1: + raise RuleError(f"failed to decrypt with all KEKs for {self._kek.name}") from e + raise RuleError("No KEK found for decryption") + + def _get_kms_key_ids(self) -> List[str]: + kms_key_ids = [self._kek.kms_key_id] + alternate_kms_key_ids = None + if self._kek.kms_props is not None: + alternate_kms_key_ids = self._kek.kms_props.properties.get(ENCRYPT_ALTERNATE_KMS_KEY_IDS) + if alternate_kms_key_ids is None: + alternate_kms_key_ids = self._config.get(ENCRYPT_ALTERNATE_KMS_KEY_IDS) + if alternate_kms_key_ids is not None: + # Split the comma-separated list of alternate KMS key IDs and append to kms_key_ids + kms_key_ids.extend([id.strip() for id in alternate_kms_key_ids.split(',') if id.strip()]) + return kms_key_ids + + def _get_aead(self, config: dict, kms_type: str, kms_key_id: str) -> aead.Aead: + kek_url = kms_type + "://" + kms_key_id kms_client = self._get_kms_client(config, kek_url) return kms_client.get_aead(kek_url) diff --git a/tests/schema_registry/_async/test_avro_serdes.py b/tests/schema_registry/_async/test_avro_serdes.py index bb95f4098..973ae8889 100644 --- a/tests/schema_registry/_async/test_avro_serdes.py +++ b/tests/schema_registry/_async/test_avro_serdes.py @@ -1202,6 +1202,68 @@ async def test_avro_payload_encryption(): assert obj == obj2 +async def test_avro_encryption_alternate_keks(): + executor = EncryptionExecutor.register_with_clock(FakeClock()) + + conf = {'url': _BASE_URL} + client = AsyncSchemaRegistryClient.new_client(conf) + ser_conf = {'auto.register.schemas': False, 'use.latest.version': True} + rule_conf = {'secret': 'mysecret', 'encrypt.alternate.kms.key.ids': 'mykey2,mykey3'} + schema = { + 'type': 'record', + 'name': 'test', + 'fields': [ + {'name': 'intField', 'type': 'int'}, + {'name': 'doubleField', 'type': 'double'}, + {'name': 'stringField', 'type': 'string', 'confluent:tags': ['PII']}, + {'name': 'booleanField', 'type': 'boolean'}, + {'name': 'bytesField', 'type': 'bytes', 'confluent:tags': ['PII']}, + ] + } + + rule = Rule( + "test-encrypt", + "", + RuleKind.TRANSFORM, + RuleMode.WRITEREAD, + "ENCRYPT_PAYLOAD", + None, + RuleParams({ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey" + }), + None, + None, + "ERROR,NONE", + False + ) + await client.register_schema(_SUBJECT, Schema( + json.dumps(schema), + "AVRO", + [], + None, + RuleSet(None, None, [rule]) + )) + + obj = { + 'intField': 123, + 'doubleField': 45.67, + 'stringField': 'hi', + 'booleanField': True, + 'bytesField': b'foobar', + } + ser = await AsyncAvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) + dek_client = executor.client + ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) + obj_bytes = await ser(obj, ser_ctx) + + deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf) + executor.client = dek_client + obj2 = await deser(obj_bytes, ser_ctx) + assert obj == obj2 + + async def test_avro_encryption_deterministic(): executor = FieldEncryptionExecutor.register_with_clock(FakeClock()) diff --git a/tests/schema_registry/_sync/test_avro_serdes.py b/tests/schema_registry/_sync/test_avro_serdes.py index 735d0f29a..9cb6477dc 100644 --- a/tests/schema_registry/_sync/test_avro_serdes.py +++ b/tests/schema_registry/_sync/test_avro_serdes.py @@ -1202,6 +1202,68 @@ def test_avro_payload_encryption(): assert obj == obj2 +def test_avro_encryption_alternate_keks(): + executor = EncryptionExecutor.register_with_clock(FakeClock()) + + conf = {'url': _BASE_URL} + client = SchemaRegistryClient.new_client(conf) + ser_conf = {'auto.register.schemas': False, 'use.latest.version': True} + rule_conf = {'secret': 'mysecret', 'encrypt.alternate.kms.key.ids': 'mykey2,mykey3'} + schema = { + 'type': 'record', + 'name': 'test', + 'fields': [ + {'name': 'intField', 'type': 'int'}, + {'name': 'doubleField', 'type': 'double'}, + {'name': 'stringField', 'type': 'string', 'confluent:tags': ['PII']}, + {'name': 'booleanField', 'type': 'boolean'}, + {'name': 'bytesField', 'type': 'bytes', 'confluent:tags': ['PII']}, + ] + } + + rule = Rule( + "test-encrypt", + "", + RuleKind.TRANSFORM, + RuleMode.WRITEREAD, + "ENCRYPT_PAYLOAD", + None, + RuleParams({ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey" + }), + None, + None, + "ERROR,NONE", + False + ) + client.register_schema(_SUBJECT, Schema( + json.dumps(schema), + "AVRO", + [], + None, + RuleSet(None, None, [rule]) + )) + + obj = { + 'intField': 123, + 'doubleField': 45.67, + 'stringField': 'hi', + 'booleanField': True, + 'bytesField': b'foobar', + } + ser = AvroSerializer(client, schema_str=None, conf=ser_conf, rule_conf=rule_conf) + dek_client = executor.client + ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE) + obj_bytes = ser(obj, ser_ctx) + + deser = AvroDeserializer(client, rule_conf=rule_conf) + executor.client = dek_client + obj2 = deser(obj_bytes, ser_ctx) + assert obj == obj2 + + def test_avro_encryption_deterministic(): executor = FieldEncryptionExecutor.register_with_clock(FakeClock())