@@ -299,6 +299,7 @@ def __init__(self, version=None, header_size=IMAGE_HEADER_SIZE,
299299 self .enctlv_len = 0
300300 self .max_align = max (DEFAULT_MAX_ALIGN , align ) if max_align is None else int (max_align )
301301 self .non_bootable = non_bootable
302+ self .key_ids = None
302303
303304 if self .max_align == DEFAULT_MAX_ALIGN :
304305 self .boot_magic = bytes ([
@@ -472,33 +473,41 @@ def ecies_hkdf(self, enckey, plainkey, hmac_sha_alg):
472473 format = PublicFormat .Raw )
473474 return cipherkey , ciphermac , pubk
474475
475- def create (self , key , public_key_format , enckey , dependencies = None ,
476+ def create (self , keys , public_key_format , enckey , dependencies = None ,
476477 sw_type = None , custom_tlvs = None , compression_tlvs = None ,
477478 compression_type = None , encrypt_keylen = 128 , clear = False ,
478479 fixed_sig = None , pub_key = None , vector_to_sign = None ,
479480 user_sha = 'auto' , hmac_sha = 'auto' , is_pure = False , keep_comp_size = False ,
480481 dont_encrypt = False ):
481482 self .enckey = enckey
482483
483- # key decides on sha, then pub_key; of both are none default is used
484- check_key = key if key is not None else pub_key
484+ # key decides on sha, then pub_key; if both are none default is used
485+ check_key = keys [ 0 ] if keys [ 0 ] is not None else pub_key
485486 hash_algorithm , hash_tlv = key_and_user_sha_to_alg_and_tlv (check_key , user_sha , is_pure )
486487
487488 # Calculate the hash of the public key
488- if key is not None :
489- pub = key .get_public_bytes ()
490- sha = hash_algorithm ()
491- sha .update (pub )
492- pubbytes = sha .digest ()
493- elif pub_key is not None :
494- if hasattr (pub_key , 'sign' ):
495- print (os .path .basename (__file__ ) + ": sign the payload" )
496- pub = pub_key .get_public_bytes ()
497- sha = hash_algorithm ()
498- sha .update (pub )
499- pubbytes = sha .digest ()
489+ pub_digests = []
490+ pub_list = []
491+
492+ if keys is None :
493+ if pub_key is not None :
494+ if hasattr (pub_key , 'sign' ):
495+ print (os .path .basename (__file__ ) + ": sign the payload" )
496+ pub = pub_key .get_public_bytes ()
497+ sha = hash_algorithm ()
498+ sha .update (pub )
499+ pubbytes = sha .digest ()
500+ else :
501+ pubbytes = bytes (hashlib .sha256 ().digest_size )
500502 else :
501- pubbytes = bytes (hashlib .sha256 ().digest_size )
503+ for key in keys or []:
504+ pub = key .get_public_bytes ()
505+ sha = hash_algorithm ()
506+ sha .update (pub )
507+ pubbytes = sha .digest ()
508+ pub_digests .append (pubbytes )
509+ pub_list .append (pub )
510+
502511
503512 protected_tlv_size = 0
504513
@@ -526,10 +535,14 @@ def create(self, key, public_key_format, enckey, dependencies=None,
526535 # value later.
527536 digest = bytes (hash_algorithm ().digest_size )
528537
538+ if pub_digests :
539+ boot_pub_digest = pub_digests [0 ]
540+ else :
541+ boot_pub_digest = pubbytes
529542 # Create CBOR encoded boot record
530543 boot_record = create_sw_component_data (sw_type , image_version ,
531544 hash_tlv , digest ,
532- pubbytes )
545+ boot_pub_digest )
533546
534547 protected_tlv_size += TLV_SIZE + len (boot_record )
535548
@@ -648,33 +661,39 @@ def create(self, key, public_key_format, enckey, dependencies=None,
648661 print (os .path .basename (__file__ ) + ': export digest' )
649662 return
650663
651- if self . key_ids is not None :
652- self . _add_key_id_tlv_to_unprotected ( tlv , self . key_ids [ 0 ] )
664+ if fixed_sig is not None and keys is not None :
665+ raise click . UsageError ( "Can not sign using key and provide fixed-signature at the same time" )
653666
654- if key is not None or fixed_sig is not None :
655- if public_key_format == 'hash' :
656- tlv .add ('KEYHASH' , pubbytes )
657- else :
658- tlv .add ('PUBKEY' , pub )
667+ if fixed_sig is not None :
668+ tlv .add (pub_key .sig_tlv (), fixed_sig ['value' ])
669+ self .signatures [0 ] = fixed_sig ['value' ]
670+ else :
671+ # Multi-signature handling: iterate through each provided key and sign.
672+ self .signatures = []
673+ for i , key in enumerate (keys ):
674+ # If key IDs are provided, and we have enough for this key, add it first.
675+ if self .key_ids is not None and len (self .key_ids ) > i :
676+ # Convert key id (an integer) to 4-byte defined endian bytes.
677+ kid_bytes = self .key_ids [i ].to_bytes (4 , self .endian )
678+ tlv .add ('KEYID' , kid_bytes ) # Using the TLV tag that corresponds to key IDs.
679+
680+ if public_key_format == 'hash' :
681+ tlv .add ('KEYHASH' , pub_digests [i ])
682+ else :
683+ tlv .add ('PUBKEY' , pub_list [i ])
659684
660- if key is not None and fixed_sig is None :
661685 # `sign` expects the full image payload (hashing done
662686 # internally), while `sign_digest` expects only the digest
663687 # of the payload
664-
665688 if hasattr (key , 'sign' ):
666689 print (os .path .basename (__file__ ) + ": sign the payload" )
667690 sig = key .sign (bytes (self .payload ))
668691 else :
669692 print (os .path .basename (__file__ ) + ": sign the digest" )
670693 sig = key .sign_digest (message )
671694 tlv .add (key .sig_tlv (), sig )
672- self .signature = sig
673- elif fixed_sig is not None and key is None :
674- tlv .add (pub_key .sig_tlv (), fixed_sig ['value' ])
675- self .signature = fixed_sig ['value' ]
676- else :
677- raise click .UsageError ("Can not sign using key and provide fixed-signature at the same time" )
695+ self .signatures .append (sig )
696+
678697
679698 # At this point the image was hashed + signed, we can remove the
680699 # protected TLVs from the payload (will be re-added later)
@@ -738,7 +757,7 @@ def get_struct_endian(self):
738757 return STRUCT_ENDIAN_DICT [self .endian ]
739758
740759 def get_signature (self ):
741- return self .signature
760+ return self .signatures
742761
743762 def get_infile_data (self ):
744763 return self .infile_data
@@ -848,75 +867,99 @@ def verify(imgfile, key):
848867 if magic != IMAGE_MAGIC :
849868 return VerifyResult .INVALID_MAGIC , None , None , None
850869
870+ # Locate the first TLV info header
851871 tlv_off = header_size + img_size
852872 tlv_info = b [tlv_off :tlv_off + TLV_INFO_SIZE ]
853873 magic , tlv_tot = struct .unpack ('HH' , tlv_info )
874+
875+ # If it's the protected-TLV block, skip it
854876 if magic == TLV_PROT_INFO_MAGIC :
855- tlv_off += tlv_tot
877+ tlv_off += TLV_INFO_SIZE + tlv_tot
856878 tlv_info = b [tlv_off :tlv_off + TLV_INFO_SIZE ]
857879 magic , tlv_tot = struct .unpack ('HH' , tlv_info )
858880
859881 if magic != TLV_INFO_MAGIC :
860882 return VerifyResult .INVALID_TLV_INFO_MAGIC , None , None , None
861883
862- # This is set by existence of TLV SIG_PURE
863- is_pure = False
884+ # Define the unprotected-TLV window
885+ unprot_off = tlv_off + TLV_INFO_SIZE
886+ unprot_end = unprot_off + tlv_tot
864887
865- prot_tlv_size = tlv_off
866- hash_region = b [:prot_tlv_size ]
867- tlv_end = tlv_off + tlv_tot
868- tlv_off += TLV_INFO_SIZE # skip tlv info
888+ # Region up to the start of unprotected TLVs is hashed
889+ prot_tlv_end = unprot_off - TLV_INFO_SIZE
890+ hash_region = b [:prot_tlv_end ]
869891
870- # First scan all TLVs in search of SIG_PURE
871- while tlv_off < tlv_end :
872- tlv = b [tlv_off :tlv_off + TLV_SIZE ]
892+ # This is set by existence of TLV SIG_PURE
893+ is_pure = False
894+ scan_off = unprot_off
895+ while scan_off < unprot_end :
896+ tlv = b [scan_off :scan_off + TLV_SIZE ]
873897 tlv_type , _ , tlv_len = struct .unpack ('BBH' , tlv )
874898 if tlv_type == TLV_VALUES ['SIG_PURE' ]:
875899 is_pure = True
876900 break
877- tlv_off += TLV_SIZE + tlv_len
901+ scan_off += TLV_SIZE + tlv_len
878902
903+ if key is not None and not isinstance (key , list ):
904+ key = [key ]
905+
906+ verify_results = []
907+ scan_off = unprot_off
879908 digest = None
880- tlv_off = prot_tlv_size
881- tlv_end = tlv_off + tlv_tot
882- tlv_off += TLV_INFO_SIZE # skip tlv info
883- while tlv_off < tlv_end :
884- tlv = b [tlv_off : tlv_off + TLV_SIZE ]
909+ prot_tlv_size = unprot_off - TLV_INFO_SIZE
910+
911+ # Verify hash and signatures
912+ while scan_off < unprot_end :
913+ tlv = b [scan_off : scan_off + TLV_SIZE ]
885914 tlv_type , _ , tlv_len = struct .unpack ('BBH' , tlv )
886915 if is_sha_tlv (tlv_type ):
887- if not tlv_matches_key_type (tlv_type , key ):
916+ if not tlv_matches_key_type (tlv_type , key [ 0 ] ):
888917 return VerifyResult .KEY_MISMATCH , None , None , None
889- off = tlv_off + TLV_SIZE
918+ off = scan_off + TLV_SIZE
890919 digest = get_digest (tlv_type , hash_region )
891- if digest == b [off :off + tlv_len ]:
892- if key is None :
893- return VerifyResult .OK , version , digest , None
894- else :
895- return VerifyResult .INVALID_HASH , None , None , None
896- elif not is_pure and key is not None and tlv_type == TLV_VALUES [key .sig_tlv ()]:
897- off = tlv_off + TLV_SIZE
898- tlv_sig = b [off :off + tlv_len ]
899- payload = b [:prot_tlv_size ]
900- try :
901- if hasattr (key , 'verify' ):
902- key .verify (tlv_sig , payload )
903- else :
904- key .verify_digest (tlv_sig , digest )
905- return VerifyResult .OK , version , digest , None
906- except InvalidSignature :
907- # continue to next TLV
908- pass
920+ if digest != b [off :off + tlv_len ]:
921+ verify_results .append (("Digest" , "INVALID_HASH" ))
922+
923+ elif not is_pure and key is not None and tlv_type == TLV_VALUES [key [0 ].sig_tlv ()]:
924+ for idx , k in enumerate (key ):
925+ if tlv_type == TLV_VALUES [k .sig_tlv ()]:
926+ off = scan_off + TLV_SIZE
927+ tlv_sig = b [off :off + tlv_len ]
928+ payload = b [:prot_tlv_size ]
929+ try :
930+ if hasattr (k , 'verify' ):
931+ k .verify (tlv_sig , payload )
932+ else :
933+ k .verify_digest (tlv_sig , digest )
934+ verify_results .append ((f"Key { idx } " , "OK" ))
935+ break
936+ except InvalidSignature :
937+ # continue to next TLV
938+ verify_results .append ((f"Key { idx } " , "INVALID_SIGNATURE" ))
939+ continue
940+
909941 elif is_pure and key is not None and tlv_type in ALLOWED_PURE_SIG_TLVS :
910- off = tlv_off + TLV_SIZE
942+ # pure signature verification
943+ off = scan_off + TLV_SIZE
911944 tlv_sig = b [off :off + tlv_len ]
945+ k = key [0 ]
912946 try :
913- key .verify_digest (tlv_sig , hash_region )
947+ k .verify_digest (tlv_sig , hash_region )
914948 return VerifyResult .OK , version , None , tlv_sig
915949 except InvalidSignature :
916- # continue to next TLV
917- pass
918- tlv_off += TLV_SIZE + tlv_len
919- return VerifyResult .INVALID_SIGNATURE , None , None , None
950+ return VerifyResult .INVALID_SIGNATURE , None , None , None
951+
952+ scan_off += TLV_SIZE + tlv_len
953+ # Now print out the verification results:
954+ for k , result in verify_results :
955+ print (f"{ k } : { result } " )
956+
957+ # Decide on a final return (for example, OK only if at least one signature is valid)
958+ if any (result == "OK" for _ , result in verify_results ):
959+ return VerifyResult .OK , version , digest , None
960+ else :
961+ return VerifyResult .INVALID_SIGNATURE , None , None , None
962+
920963
921964 def set_key_ids (self , key_ids ):
922965 """Set list of key IDs (integers) to be inserted before each signature."""
0 commit comments