@@ -191,6 +191,8 @@ def constant_segment_with_tensor_alignment(
191191 # the end of the file.
192192 self .assertGreaterEqual (eh .segment_base_offset , eh .program_size )
193193 self .assertLess (eh .segment_base_offset , len (pte_data ))
194+ # Segment data_size should be non-zero since there are segments.
195+ self .assertGreater (eh .segment_data_size , 0 )
194196
195197 # Peek inside the actual flatbuffer data to see the segments.
196198 program_with_segments = _json_to_program (_program_flatbuffer_to_json (pte_data ))
@@ -232,6 +234,8 @@ def constant_segment_with_tensor_alignment(
232234 # Check segment data.
233235 offsets = subsegment_offsets .offsets
234236 segment_data : bytes = pte_data [eh .segment_base_offset :]
237+ # Check segment data size.
238+ self .assertEqual (len (segment_data ), eh .segment_data_size )
235239
236240 # tensor[1]: padding.
237241 self .assertEqual (
@@ -514,6 +518,8 @@ def test_round_trip_with_segments(self) -> None:
514518 # the end of the file.
515519 self .assertGreaterEqual (eh .segment_base_offset , eh .program_size )
516520 self .assertLess (eh .segment_base_offset , len (pte_data ))
521+ # Segment data size should be non-zero since there are segments.
522+ self .assertGreater (eh .segment_data_size , 0 )
517523
518524 # Peek inside the actual flatbuffer data to see the segments. Note that
519525 # this also implicity tests the case where we try parsing the entire
@@ -566,6 +572,8 @@ def test_round_trip_with_segments(self) -> None:
566572 # Now that we've shown that the base offset is correct, slice off the
567573 # front so that all segment offsets are relative to zero.
568574 segment_data : bytes = pte_data [segment_base_offset :]
575+ # Check segment data size.
576+ self .assertEqual (len (segment_data ), eh .segment_data_size )
569577
570578 # End of the first segment. It's much smaller than the alignment,
571579 # so we know that it's followed by zeros.
@@ -729,6 +737,8 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
729737 # the end of the file.
730738 self .assertGreaterEqual (eh .segment_base_offset , eh .program_size )
731739 self .assertLess (eh .segment_base_offset , len (pte_data ))
740+ # Segment data size should be non-zero since there are segments.
741+ self .assertGreater (eh .segment_data_size , 0 )
732742
733743 # Peek inside the actual flatbuffer data to see the segments.
734744 program_with_segments = _json_to_program (_program_flatbuffer_to_json (pte_data ))
@@ -811,6 +821,8 @@ def test_constant_delegate_and_named_data_segments(self) -> None:
811821 # Now that we've shown that the base offset is correct, slice off the
812822 # front so that all segment offsets are relative to zero.
813823 segment_data : bytes = pte_data [segment_base_offset :]
824+ # Check segment data size.
825+ self .assertEqual (len (segment_data ), eh .segment_data_size )
814826
815827 # Check segment[0] for constants.
816828 offsets = subsegment_offsets .offsets
@@ -925,6 +937,8 @@ def test_named_data_segments(self) -> None:
925937 # the end of the file.
926938 self .assertGreaterEqual (eh .segment_base_offset , eh .program_size )
927939 self .assertLess (eh .segment_base_offset , len (pte_data ))
940+ # Segment data size should be non-zero since there are segments.
941+ self .assertGreater (eh .segment_data_size , 0 )
928942
929943 # Peek inside the actual flatbuffer data to see the named data segments.
930944 program_with_segments = _json_to_program (_program_flatbuffer_to_json (pte_data ))
@@ -958,6 +972,9 @@ def test_named_data_segments(self) -> None:
958972
959973 # Check the pte data for buffer values.
960974 segment_data : bytes = pte_data [eh .segment_base_offset :]
975+ # Check segment data size.
976+ self .assertEqual (len (segment_data ), eh .segment_data_size )
977+
961978 self .assertEqual (
962979 segment_data [
963980 segment_table [0 ].offset : segment_table [0 ].offset
@@ -985,18 +1002,21 @@ def test_named_data_segments(self) -> None:
9851002# the example data.
9861003EXAMPLE_PROGRAM_SIZE : int = 0x1122112233443344
9871004EXAMPLE_SEGMENT_BASE_OFFSET : int = 0x5566556677887788
1005+ EXAMPLE_SEGMENT_DATA_SIZE : int = 0x5544554433223322
9881006# This data is intentionally fragile. If the header layout or magic changes,
9891007# this test must change too. The layout of the header is a contract, not an
9901008# implementation detail.
9911009EXAMPLE_HEADER_DATA : bytes = (
9921010 # Magic bytes
9931011 b"eh00"
9941012 # uint32_t header size (little endian)
995- + b"\x18 \x00 \x00 \x00 "
1013+ + b"\x20 \x00 \x00 \x00 "
9961014 # uint64_t program size
9971015 + b"\x44 \x33 \x44 \x33 \x22 \x11 \x22 \x11 "
9981016 # uint64_t segment base offset
9991017 + b"\x88 \x77 \x88 \x77 \x66 \x55 \x66 \x55 "
1018+ # uint64_t segment data size
1019+ + b"\x22 \x33 \x22 \x33 \x44 \x55 \x44 \x55 "
10001020)
10011021
10021022
@@ -1005,6 +1025,7 @@ def test_to_bytes(self) -> None:
10051025 eh = _ExtendedHeader (
10061026 program_size = EXAMPLE_PROGRAM_SIZE ,
10071027 segment_base_offset = EXAMPLE_SEGMENT_BASE_OFFSET ,
1028+ segment_data_size = EXAMPLE_SEGMENT_DATA_SIZE ,
10081029 )
10091030 self .assertTrue (eh .is_valid ())
10101031 self .assertEqual (eh .to_bytes (), EXAMPLE_HEADER_DATA )
@@ -1013,6 +1034,7 @@ def test_to_bytes_with_non_defaults(self) -> None:
10131034 eh = _ExtendedHeader (
10141035 program_size = EXAMPLE_PROGRAM_SIZE ,
10151036 segment_base_offset = EXAMPLE_SEGMENT_BASE_OFFSET ,
1037+ segment_data_size = EXAMPLE_SEGMENT_DATA_SIZE ,
10161038 # Override the default magic and length, to demonstrate that this
10171039 # does not affect the serialized header.
10181040 magic = b"ABCD" ,
@@ -1036,6 +1058,7 @@ def test_from_bytes_valid(self) -> None:
10361058 self .assertEqual (eh .length , _ExtendedHeader .EXPECTED_LENGTH )
10371059 self .assertEqual (eh .program_size , EXAMPLE_PROGRAM_SIZE )
10381060 self .assertEqual (eh .segment_base_offset , EXAMPLE_SEGMENT_BASE_OFFSET )
1061+ self .assertEqual (eh .segment_data_size , EXAMPLE_SEGMENT_DATA_SIZE )
10391062
10401063 def test_from_bytes_with_more_data_than_necessary (self ) -> None :
10411064 # Pass in more data than necessary to parse the header.
@@ -1049,6 +1072,7 @@ def test_from_bytes_with_more_data_than_necessary(self) -> None:
10491072 self .assertEqual (eh .length , _ExtendedHeader .EXPECTED_LENGTH )
10501073 self .assertEqual (eh .program_size , EXAMPLE_PROGRAM_SIZE )
10511074 self .assertEqual (eh .segment_base_offset , EXAMPLE_SEGMENT_BASE_OFFSET )
1075+ self .assertEqual (eh .segment_data_size , EXAMPLE_SEGMENT_DATA_SIZE )
10521076
10531077 def test_from_bytes_larger_than_needed_header_size_field (self ) -> None :
10541078 # Simulate a backwards-compatibility situation. Parse a header
@@ -1059,11 +1083,13 @@ def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
10591083 # Magic bytes
10601084 b"eh00"
10611085 # uint32_t header size (little endian)
1062- + b"\x1c \x00 \x00 \x00 " # Longer than expected
1086+ + b"\x21 \x00 \x00 \x00 " # Longer than expected
10631087 # uint64_t program size
10641088 + b"\x44 \x33 \x44 \x33 \x22 \x11 \x22 \x11 "
10651089 # uint64_t segment base offset
10661090 + b"\x88 \x77 \x88 \x77 \x66 \x55 \x66 \x55 "
1091+ # uint64_t segment data size
1092+ + b"\x22 \x33 \x22 \x33 \x44 \x55 \x44 \x55 "
10671093 # uint32_t new field (ignored)
10681094 + b"\xff \xee \xff \xee "
10691095 )
@@ -1075,9 +1101,10 @@ def test_from_bytes_larger_than_needed_header_size_field(self) -> None:
10751101 self .assertTrue (eh .is_valid ())
10761102
10771103 self .assertEqual (eh .magic , _ExtendedHeader .EXPECTED_MAGIC )
1078- self .assertEqual (eh .length , 28 )
1104+ self .assertEqual (eh .length , 33 )
10791105 self .assertEqual (eh .program_size , EXAMPLE_PROGRAM_SIZE )
10801106 self .assertEqual (eh .segment_base_offset , EXAMPLE_SEGMENT_BASE_OFFSET )
1107+ self .assertEqual (eh .segment_data_size , EXAMPLE_SEGMENT_DATA_SIZE )
10811108
10821109 def test_from_bytes_not_enough_data_fails (self ) -> None :
10831110 # Parsing a truncated prefix should fail.
@@ -1090,11 +1117,13 @@ def test_from_bytes_invalid_magic(self) -> None:
10901117 # Magic bytes
10911118 b"ABCD" # Invalid
10921119 # uint32_t header size (little endian)
1093- + b"\x18 \x00 \x00 \x00 "
1120+ + b"\x20 \x00 \x00 \x00 "
10941121 # uint64_t program size
10951122 + b"\x44 \x33 \x44 \x33 \x22 \x11 \x22 \x11 "
10961123 # uint64_t segment base offset
10971124 + b"\x88 \x77 \x88 \x77 \x66 \x55 \x66 \x55 "
1125+ # uint64_t segment data size
1126+ + b"\x22 \x33 \x22 \x33 \x44 \x55 \x44 \x55 "
10981127 )
10991128
11001129 # Parse the serialized extended header.
@@ -1109,6 +1138,7 @@ def test_from_bytes_invalid_magic(self) -> None:
11091138 self .assertEqual (eh .length , _ExtendedHeader .EXPECTED_LENGTH )
11101139 self .assertEqual (eh .program_size , EXAMPLE_PROGRAM_SIZE )
11111140 self .assertEqual (eh .segment_base_offset , EXAMPLE_SEGMENT_BASE_OFFSET )
1141+ self .assertEqual (eh .segment_data_size , EXAMPLE_SEGMENT_DATA_SIZE )
11121142
11131143 def test_from_bytes_invalid_length (self ) -> None :
11141144 # An invalid serialized header
@@ -1121,6 +1151,8 @@ def test_from_bytes_invalid_length(self) -> None:
11211151 + b"\x44 \x33 \x44 \x33 \x22 \x11 \x22 \x11 "
11221152 # uint64_t segment base offset
11231153 + b"\x88 \x77 \x88 \x77 \x66 \x55 \x66 \x55 "
1154+ # uint64_t segment data size
1155+ + b"\x22 \x33 \x22 \x33 \x44 \x55 \x44 \x55 "
11241156 )
11251157
11261158 # Parse the serialized extended header.
@@ -1135,3 +1167,4 @@ def test_from_bytes_invalid_length(self) -> None:
11351167 self .assertEqual (eh .length , 16 )
11361168 self .assertEqual (eh .program_size , EXAMPLE_PROGRAM_SIZE )
11371169 self .assertEqual (eh .segment_base_offset , EXAMPLE_SEGMENT_BASE_OFFSET )
1170+ self .assertEqual (eh .segment_data_size , EXAMPLE_SEGMENT_DATA_SIZE )
0 commit comments