Skip to content

Commit 4ac37c3

Browse files
committed
First cut
1 parent 765795c commit 4ac37c3

File tree

3 files changed

+173
-5
lines changed

3 files changed

+173
-5
lines changed

src/confluent_kafka/schema_registry/rules/encryption/encrypt_executor.py

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import io
1717
import logging
1818
import time
19-
from typing import Optional, Tuple, Any
19+
from typing import Optional, Tuple, Any, List
2020

2121
from tink import aead, daead, KmsClient, kms_client_from_uri, \
2222
register_kms_client, TinkError
@@ -45,6 +45,7 @@
4545
ENCRYPT_KMS_TYPE = "encrypt.kms.type"
4646
ENCRYPT_DEK_ALGORITHM = "encrypt.dek.algorithm"
4747
ENCRYPT_DEK_EXPIRY_DAYS = "encrypt.dek.expiry.days"
48+
ENCRYPT_ALTERNATE_KMS_KEY_IDS = "encrypt.alternate.kms.key.ids"
4849

4950
MILLIS_IN_DAY = 24 * 60 * 60 * 1000
5051

@@ -279,7 +280,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek:
279280
raise RuleError(f"no dek found for {dek_id.kek_name} during consume")
280281
encrypted_dek = None
281282
if not kek.shared:
282-
primitive = self._get_aead(self._executor.config, self._kek)
283+
primitive = AeadWrapper(self._executor.config, self._kek)
283284
raw_dek = self._cryptor.generate_key()
284285
encrypted_dek = primitive.encrypt(raw_dek, self._cryptor.EMPTY_AAD)
285286
new_version = dek.version + 1 if is_expired else 1
@@ -293,7 +294,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek:
293294
key_bytes = dek.get_key_material_bytes()
294295
if key_bytes is None:
295296
if primitive is None:
296-
primitive = self._get_aead(self._executor.config, self._kek)
297+
primitive = AeadWrapper(self._executor.config, self._kek)
297298
encrypted_dek = dek.get_encrypted_key_material_bytes()
298299
raw_dek = primitive.decrypt(encrypted_dek, self._cryptor.EMPTY_AAD)
299300
dek.set_key_material(raw_dek)
@@ -410,8 +411,51 @@ def _to_object(self, field_type: FieldType, value: bytes) -> Any:
410411
return value
411412
return None
412413

413-
def _get_aead(self, config: dict, kek: Kek) -> aead.Aead:
414-
kek_url = kek.kms_type + "://" + kek.kms_key_id
414+
415+
class AeadWrapper(aead.Aead):
416+
def __init__(self, config: dict, kek: Kek):
    """Bind the wrapper to a KEK and resolve the KMS key ids it may use.

    :param config: executor configuration; consulted for alternate KMS key
        ids and handed to the AEAD lookup on every encrypt/decrypt call.
    :param kek: the key-encryption key whose own KMS key id heads the list
        of key ids.
    """
    self._config, self._kek = config, kek
    # Resolved once up front; the lookup reads both attributes set above.
    self._kms_key_ids = self._get_kms_key_ids()
420+
421+
def encrypt(self, plaintext: bytes, associated_data: bytes) -> bytes:
    """Encrypt ``plaintext`` with the first KMS key id that succeeds.

    Key ids are tried in the order held in ``self._kms_key_ids``; a failure
    on one key id is logged and the next one is attempted.

    :param plaintext: bytes to encrypt.
    :param associated_data: associated data passed through unchanged to the
        underlying AEAD primitive.
    :return: the ciphertext produced by the first key id that works.
    :raises RuleError: if every key id fails (chained to the last failure),
        or if there are no key ids to try.
    """
    last_error = None
    for kms_key_id in self._kms_key_ids:
        try:
            # Named ``primitive`` so the imported ``aead`` module is not shadowed.
            primitive = self._get_aead(self._config, self._kek.kms_type, kms_key_id)
            return primitive.encrypt(plaintext, associated_data)
        except Exception as e:
            # Any failure (AEAD lookup or the encrypt call itself) moves on to
            # the next key id; the cause is logged so failover is diagnosable.
            log.warning("failed to encrypt with kek %s and kms key id %s: %s",
                        self._kek.name, kms_key_id, e)
            last_error = e
    if last_error is not None:
        raise RuleError(f"failed to encrypt with all KEKs for {self._kek.name}") from last_error
    raise RuleError("No KEK found for encryption")
432+
433+
def decrypt(self, ciphertext: bytes, associated_data: bytes) -> bytes:
    """Decrypt ``ciphertext`` with the first KMS key id that succeeds.

    Key ids are tried in the order held in ``self._kms_key_ids``; a failure
    on one key id is logged and the next one is attempted.

    :param ciphertext: bytes to decrypt.
    :param associated_data: associated data passed through unchanged to the
        underlying AEAD primitive.
    :return: the plaintext produced by the first key id that works.
    :raises RuleError: if every key id fails (chained to the last failure),
        or if there are no key ids to try.
    """
    last_error = None
    for kms_key_id in self._kms_key_ids:
        try:
            # Named ``primitive`` so the imported ``aead`` module is not shadowed.
            primitive = self._get_aead(self._config, self._kek.kms_type, kms_key_id)
            return primitive.decrypt(ciphertext, associated_data)
        except Exception as e:
            # Any failure (AEAD lookup or the decrypt call itself) moves on to
            # the next key id; the cause is logged so failover is diagnosable.
            log.warning("failed to decrypt with kek %s and kms key id %s: %s",
                        self._kek.name, kms_key_id, e)
            last_error = e
    if last_error is not None:
        raise RuleError(f"failed to decrypt with all KEKs for {self._kek.name}") from last_error
    raise RuleError("No KEK found for decryption")
444+
445+
def _get_kms_key_ids(self) -> List[str]:
    """Return the KMS key ids to try, primary first.

    The KEK's own ``kms_key_id`` always comes first. Alternates are read from
    the ``encrypt.alternate.kms.key.ids`` entry of the KEK's ``kms_props``
    when that yields a value, otherwise from the executor config. The value
    is a comma-separated string; surrounding whitespace is trimmed and blank
    entries are dropped.

    :return: the primary key id followed by any alternate key ids.
    """
    kms_key_ids = [self._kek.kms_key_id]
    alternates = None
    if self._kek.kms_props is not None:
        alternates = self._kek.kms_props.properties.get(ENCRYPT_ALTERNATE_KMS_KEY_IDS)
    if alternates is None:
        # Nothing on the KEK itself: fall back to the executor configuration.
        alternates = self._config.get(ENCRYPT_ALTERNATE_KMS_KEY_IDS)
    if alternates is not None:
        # Strip each entry once, and avoid shadowing the ``id`` builtin.
        stripped = (key_id.strip() for key_id in alternates.split(','))
        kms_key_ids.extend(key_id for key_id in stripped if key_id)
    return kms_key_ids
456+
457+
def _get_aead(self, config: dict, kms_type: str, kms_key_id: str) -> aead.Aead:
    """Look up the AEAD primitive for a single KMS key.

    Builds the key URI as ``<kms_type>://<kms_key_id>``, obtains a KMS client
    for that URI through ``self._get_kms_client`` and asks the client for the
    AEAD bound to the same URI.

    :param config: configuration forwarded to the KMS client lookup.
    :param kms_type: KMS type, used as the URI scheme.
    :param kms_key_id: identifier of the key within that KMS.
    :return: the AEAD returned by the KMS client for the key URI.
    """
    key_uri = "://".join((kms_type, kms_key_id))
    client = self._get_kms_client(config, key_uri)
    return client.get_aead(key_uri)
417461

tests/schema_registry/_async/test_avro_serdes.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,6 +1202,68 @@ async def test_avro_payload_encryption():
12021202
assert obj == obj2
12031203

12041204

1205+
async def test_avro_encryption_alternate_keks():
    """Round-trip an Avro record through an ENCRYPT_PAYLOAD rule while
    alternate KMS key ids are supplied in the rule configuration."""
    executor = EncryptionExecutor.register_with_clock(FakeClock())

    client = AsyncSchemaRegistryClient.new_client({'url': _BASE_URL})
    serializer_conf = {'auto.register.schemas': False, 'use.latest.version': True}
    # NOTE(review): the primary key id 'mykey' presumably works here, so the
    # alternates may never actually be tried - confirm this covers failover.
    rule_conf = {'secret': 'mysecret', 'encrypt.alternate.kms.key.ids': 'mykey2,mykey3'}

    record_fields = [
        {'name': 'intField', 'type': 'int'},
        {'name': 'doubleField', 'type': 'double'},
        {'name': 'stringField', 'type': 'string', 'confluent:tags': ['PII']},
        {'name': 'booleanField', 'type': 'boolean'},
        {'name': 'bytesField', 'type': 'bytes', 'confluent:tags': ['PII']},
    ]
    schema = {'type': 'record', 'name': 'test', 'fields': record_fields}

    encrypt_params = RuleParams({
        "encrypt.kek.name": "kek1",
        "encrypt.kms.type": "local-kms",
        "encrypt.kms.key.id": "mykey",
    })
    rule = Rule(
        "test-encrypt", "", RuleKind.TRANSFORM, RuleMode.WRITEREAD,
        "ENCRYPT_PAYLOAD", None, encrypt_params, None, None, "ERROR,NONE", False,
    )
    rule_set = RuleSet(None, None, [rule])
    registered = Schema(json.dumps(schema), "AVRO", [], None, rule_set)
    await client.register_schema(_SUBJECT, registered)

    obj = {
        'intField': 123,
        'doubleField': 45.67,
        'stringField': 'hi',
        'booleanField': True,
        'bytesField': b'foobar',
    }

    ser = await AsyncAvroSerializer(client, schema_str=None, conf=serializer_conf, rule_conf=rule_conf)
    # Remember the executor's client as it is right after the serializer is built.
    dek_client = executor.client
    ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
    obj_bytes = await ser(obj, ser_ctx)

    deser = await AsyncAvroDeserializer(client, rule_conf=rule_conf)
    # Put the remembered client back before decoding - presumably building the
    # deserializer replaces executor.client; verify against the executor.
    executor.client = dek_client
    obj2 = await deser(obj_bytes, ser_ctx)
    assert obj == obj2
1265+
1266+
12051267
async def test_avro_encryption_deterministic():
12061268
executor = FieldEncryptionExecutor.register_with_clock(FakeClock())
12071269

tests/schema_registry/_sync/test_avro_serdes.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1202,6 +1202,68 @@ def test_avro_payload_encryption():
12021202
assert obj == obj2
12031203

12041204

1205+
def test_avro_encryption_alternate_keks():
    """Round-trip an Avro record through an ENCRYPT_PAYLOAD rule while
    alternate KMS key ids are supplied in the rule configuration."""
    executor = EncryptionExecutor.register_with_clock(FakeClock())

    client = SchemaRegistryClient.new_client({'url': _BASE_URL})
    serializer_conf = {'auto.register.schemas': False, 'use.latest.version': True}
    # NOTE(review): the primary key id 'mykey' presumably works here, so the
    # alternates may never actually be tried - confirm this covers failover.
    rule_conf = {'secret': 'mysecret', 'encrypt.alternate.kms.key.ids': 'mykey2,mykey3'}

    record_fields = [
        {'name': 'intField', 'type': 'int'},
        {'name': 'doubleField', 'type': 'double'},
        {'name': 'stringField', 'type': 'string', 'confluent:tags': ['PII']},
        {'name': 'booleanField', 'type': 'boolean'},
        {'name': 'bytesField', 'type': 'bytes', 'confluent:tags': ['PII']},
    ]
    schema = {'type': 'record', 'name': 'test', 'fields': record_fields}

    encrypt_params = RuleParams({
        "encrypt.kek.name": "kek1",
        "encrypt.kms.type": "local-kms",
        "encrypt.kms.key.id": "mykey",
    })
    rule = Rule(
        "test-encrypt", "", RuleKind.TRANSFORM, RuleMode.WRITEREAD,
        "ENCRYPT_PAYLOAD", None, encrypt_params, None, None, "ERROR,NONE", False,
    )
    rule_set = RuleSet(None, None, [rule])
    registered = Schema(json.dumps(schema), "AVRO", [], None, rule_set)
    client.register_schema(_SUBJECT, registered)

    obj = {
        'intField': 123,
        'doubleField': 45.67,
        'stringField': 'hi',
        'booleanField': True,
        'bytesField': b'foobar',
    }

    ser = AvroSerializer(client, schema_str=None, conf=serializer_conf, rule_conf=rule_conf)
    # Remember the executor's client as it is right after the serializer is built.
    dek_client = executor.client
    ser_ctx = SerializationContext(_TOPIC, MessageField.VALUE)
    obj_bytes = ser(obj, ser_ctx)

    deser = AvroDeserializer(client, rule_conf=rule_conf)
    # Put the remembered client back before decoding - presumably building the
    # deserializer replaces executor.client; verify against the executor.
    executor.client = dek_client
    obj2 = deser(obj_bytes, ser_ctx)
    assert obj == obj2
1265+
1266+
12051267
def test_avro_encryption_deterministic():
12061268
executor = FieldEncryptionExecutor.register_with_clock(FakeClock())
12071269

0 commit comments

Comments
 (0)