16
16
import io
17
17
import logging
18
18
import time
19
- from typing import Optional , Tuple , Any
19
+ from typing import Optional , Tuple , Any , List
20
20
21
21
from tink import aead , daead , KmsClient , kms_client_from_uri , \
22
22
register_kms_client , TinkError
45
45
ENCRYPT_KMS_TYPE = "encrypt.kms.type"
46
46
ENCRYPT_DEK_ALGORITHM = "encrypt.dek.algorithm"
47
47
ENCRYPT_DEK_EXPIRY_DAYS = "encrypt.dek.expiry.days"
48
+ ENCRYPT_ALTERNATE_KMS_KEY_IDS = "encrypt.alternate.kms.key.ids"
48
49
49
50
MILLIS_IN_DAY = 24 * 60 * 60 * 1000
50
51
@@ -279,7 +280,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek:
279
280
raise RuleError (f"no dek found for { dek_id .kek_name } during consume" )
280
281
encrypted_dek = None
281
282
if not kek .shared :
282
- primitive = self . _get_aead (self ._executor .config , self ._kek )
283
+ primitive = AeadWrapper (self ._executor .config , self ._kek )
283
284
raw_dek = self ._cryptor .generate_key ()
284
285
encrypted_dek = primitive .encrypt (raw_dek , self ._cryptor .EMPTY_AAD )
285
286
new_version = dek .version + 1 if is_expired else 1
@@ -293,7 +294,7 @@ def _get_or_create_dek(self, ctx: RuleContext, version: Optional[int]) -> Dek:
293
294
key_bytes = dek .get_key_material_bytes ()
294
295
if key_bytes is None :
295
296
if primitive is None :
296
- primitive = self . _get_aead (self ._executor .config , self ._kek )
297
+ primitive = AeadWrapper (self ._executor .config , self ._kek )
297
298
encrypted_dek = dek .get_encrypted_key_material_bytes ()
298
299
raw_dek = primitive .decrypt (encrypted_dek , self ._cryptor .EMPTY_AAD )
299
300
dek .set_key_material (raw_dek )
@@ -410,8 +411,51 @@ def _to_object(self, field_type: FieldType, value: bytes) -> Any:
410
411
return value
411
412
return None
412
413
413
- def _get_aead (self , config : dict , kek : Kek ) -> aead .Aead :
414
- kek_url = kek .kms_type + "://" + kek .kms_key_id
414
+
415
+ class AeadWrapper (aead .Aead ):
416
+ def __init__ (self , config : dict , kek : Kek ):
417
+ self ._config = config
418
+ self ._kek = kek
419
+ self ._kms_key_ids = self ._get_kms_key_ids ()
420
+
421
+ def encrypt (self , plaintext : bytes , associated_data : bytes ) -> bytes :
422
+ for index , kms_key_id in enumerate (self ._kms_key_ids ):
423
+ try :
424
+ aead = self ._get_aead (self ._config , self ._kek .kms_type , kms_key_id )
425
+ return aead .encrypt (plaintext , associated_data )
426
+ except Exception as e :
427
+ log .warning ("failed to encrypt with kek %s and kms key id %s" ,
428
+ self ._kek .name , kms_key_id )
429
+ if index == len (self ._kms_key_ids ) - 1 :
430
+ raise RuleError (f"failed to encrypt with all KEKs for { self ._kek .name } " ) from e
431
+ raise RuleError ("No KEK found for encryption" )
432
+
433
+ def decrypt (self , ciphertext : bytes , associated_data : bytes ) -> bytes :
434
+ for index , kms_key_id in enumerate (self ._kms_key_ids ):
435
+ try :
436
+ aead = self ._get_aead (self ._config , self ._kek .kms_type , kms_key_id )
437
+ return aead .decrypt (ciphertext , associated_data )
438
+ except Exception as e :
439
+ log .warning ("failed to decrypt with kek %s and kms key id %s" ,
440
+ self ._kek .name , kms_key_id )
441
+ if index == len (self ._kms_key_ids ) - 1 :
442
+ raise RuleError (f"failed to decrypt with all KEKs for { self ._kek .name } " ) from e
443
+ raise RuleError ("No KEK found for decryption" )
444
+
445
+ def _get_kms_key_ids (self ) -> List [str ]:
446
+ kms_key_ids = [self ._kek .kms_key_id ]
447
+ alternate_kms_key_ids = None
448
+ if self ._kek .kms_props is not None :
449
+ alternate_kms_key_ids = self ._kek .kms_props .properties .get (ENCRYPT_ALTERNATE_KMS_KEY_IDS )
450
+ if alternate_kms_key_ids is None :
451
+ alternate_kms_key_ids = self ._config .get (ENCRYPT_ALTERNATE_KMS_KEY_IDS )
452
+ if alternate_kms_key_ids is not None :
453
+ # Split the comma-separated list of alternate KMS key IDs and append to kms_key_ids
454
+ kms_key_ids .extend ([id .strip () for id in alternate_kms_key_ids .split (',' ) if id .strip ()])
455
+ return kms_key_ids
456
+
457
+ def _get_aead (self , config : dict , kms_type : str , kms_key_id : str ) -> aead .Aead :
458
+ kek_url = kms_type + "://" + kms_key_id
415
459
kms_client = self ._get_kms_client (config , kek_url )
416
460
return kms_client .get_aead (kek_url )
417
461
0 commit comments