diff --git a/src/beacon_tests.zig b/src/beacon_tests.zig index 047228e..0744957 100644 --- a/src/beacon_tests.zig +++ b/src/beacon_tests.zig @@ -779,7 +779,7 @@ test "hashTreeRoot for pointer types" { try hashTreeRoot(Sha256, *u32, &value, &hash, std.testing.allocator); var deserialized: u32 = undefined; - try deserialize(u32, &hash, &deserialized, std.testing.allocator); + try deserialize(u32, hash[0..4], &deserialized, std.testing.allocator); try expect(deserialized == value); } @@ -790,7 +790,7 @@ test "hashTreeRoot for pointer types" { try hashTreeRoot(Sha256, *[4]u8, values_ptr, &hash, std.testing.allocator); var deserialized: [4]u8 = undefined; - try deserialize([4]u8, &hash, &deserialized, std.testing.allocator); + try deserialize([4]u8, hash[0..4], &deserialized, std.testing.allocator); try expect(std.mem.eql(u8, &deserialized, values_ptr)); } diff --git a/src/lib.zig b/src/lib.zig index 5abb2bf..6a72140 100644 --- a/src/lib.zig +++ b/src/lib.zig @@ -119,6 +119,107 @@ pub fn isFixedSizeObject(T: type) !bool { return true; } +/// Returns the maximum possible serialized byte length for type `T`. +/// Useful for pre-allocating buffers or validating input bounds. +/// For variable-length types (e.g. slice), use a type that encodes max length (e.g. List(T, N)) or returns error. +pub fn maxInLength(T: type) !usize { + if (comptime std.meta.hasFn(T, "maxInLength")) { + return T.maxInLength(); + } + + const info = @typeInfo(T); + return switch (info) { + .int => @sizeOf(T), + .bool => @as(usize, 1), + .null => @as(usize, 0), + .array => |array| if (array.child == bool) + (array.len + 7) / 8 + else blk: { + const child_max = try maxInLength(array.child); + if (try isFixedSizeObject(array.child)) { + break :blk array.len * child_max; + } else { + break :blk array.len * child_max + 4 * array.len; + } + }, + .optional => 1 + try maxInLength(info.optional.child), + .pointer => |ptr| switch (ptr.size) { + .slice => error.NoMaxInLengthAvailable, + .one => maxInLength(ptr.child), + else => error.NoMaxInLengthAvailable, + }, + .@"struct" => |str| blk: { + var total: usize = 0; + inline for (str.fields) |field| { + if (try isFixedSizeObject(field.type)) { + total += try maxInLength(field.type); + } else { + total += 4 + try maxInLength(field.type); + } + } + break :blk total; + }, + .@"union" => |u| blk: { + if (u.tag_type == null) return error.UnionIsNotTagged; + var m: usize = 0; + inline for (u.fields) |f| { + const n = try maxInLength(f.type); + if (n > m) m = n; + } + break :blk 1 + m; + }, + else => error.NoMaxInLengthAvailable, + }; +} + +/// Returns the minimum possible serialized byte length for type `T`. +/// Used together with maxInLength to validate input bounds before deserializing. +pub fn minInLength(T: type) !usize { + if (comptime std.meta.hasFn(T, "minInLength")) { + return T.minInLength(); + } + + const info = @typeInfo(T); + return switch (info) { + .int => @sizeOf(T), + .bool => @as(usize, 1), + .null => @as(usize, 0), + .array => |array| if (array.child == bool) + (array.len + 7) / 8 + else if (try isFixedSizeObject(array.child)) + array.len * try minInLength(array.child) + else + array.len * @sizeOf(u32) + array.len * try minInLength(array.child), + .optional => 1, + .pointer => |ptr| switch (ptr.size) { + .slice => error.NoMinInLengthAvailable, + .one => minInLength(ptr.child), + else => error.NoMinInLengthAvailable, + }, + .@"struct" => |str| blk: { + var total: usize = 0; + inline for (str.fields) |field| { + if (try isFixedSizeObject(field.type)) { + total += try minInLength(field.type); + } else { + total += 4 + try minInLength(field.type); + } + } + break :blk total; + }, + .@"union" => |u| blk: { + if (u.tag_type == null) return error.UnionIsNotTagged; + var m: usize = std.math.maxInt(usize); + inline for (u.fields) |f| { + const n = try minInLength(f.type); + if (n < m) m = n; + } + break :blk 1 + m; + }, + else => error.NoMinInLengthAvailable, + }; +} + /// Provides the generic serialization of any `data` var to SSZ. The /// serialization is written to the `ArrayList` `l`. pub fn serialize(T: type, data: T, l: *ArrayList(u8), allocator: Allocator) !void { @@ -312,12 +413,29 @@ pub fn serialize(T: type, data: T, l: *ArrayList(u8), allocator: Allocator) !voi } } -/// Takes a byte array containing the serialized payload of type `T` (with -/// possible trailing data) and deserializes it into the `T` object pointed -/// at by `out`. +/// Takes a byte array containing the serialized payload of type `T` and +/// deserializes it into the `T` object pointed at by `out`. +/// The payload must be within [minInLength, maxInLength] bounds for `T`. pub fn deserialize(T: type, serialized: []const u8, out: *T, allocator: ?Allocator) !void { + const has_custom_decode = comptime std.meta.hasFn(T, "sszDecode"); + const enforce_min = !has_custom_decode or comptime std.meta.hasFn(T, "minInLength"); + const enforce_max = !has_custom_decode or comptime std.meta.hasFn(T, "maxInLength"); + + // Bounds check: ensure serialized length is within [minInLength, maxInLength] + const min_len: ?usize = if (enforce_min) blk: { + const m = minInLength(T) catch break :blk null; + break :blk m; + } else null; + if (min_len) |m| if (serialized.len < m) return error.PayloadTooSmall; + + const max_len: ?usize = if (enforce_max) blk: { + const m = maxInLength(T) catch break :blk null; + break :blk m; + } else null; + if (max_len) |m| if (serialized.len > m) return error.PayloadTooLarge; + // shortcut if the type implements its own decode method - if (comptime std.meta.hasFn(T, "sszDecode")) { + if (has_custom_decode) { return T.sszDecode(serialized, out, allocator); } @@ -497,7 +615,7 @@ pub fn deserialize(T: type, serialized: []const u8, out: *T, allocator: ?Allocat .@"union" => { // Read the type index var union_index: u8 = undefined; - try deserialize(u8, serialized, &union_index, allocator); + try deserialize(u8, serialized[0..1], &union_index, allocator); // Use the index to figure out which type must // be deserialized. diff --git a/src/tests.zig b/src/tests.zig index 9d4691a..6e52c49 100644 --- a/src/tests.zig +++ b/src/tests.zig @@ -1121,6 +1121,88 @@ test "isFixedSizeObject correctly identifies List/Bitlist as variable-size" { try expect(!try isFixedSizeObject(StructWithList)); } +test "maxInLength for fixed and variable types" { + try expect(try libssz.maxInLength(u8) == 1); + try expect(try libssz.maxInLength(u64) == 8); + try expect(try libssz.maxInLength(bool) == 1); + try expect(try libssz.maxInLength([4]u8) == 4); + try expect(try libssz.maxInLength([10]bool) == (10 + 7) / 8); + + const ListU64 = utils.List(u64, 16); + try expect(try ListU64.maxInLength() == 16 * 8); + + const Bitlist32 = utils.Bitlist(32); + try expect(Bitlist32.maxInLength() == (32 + 7 + 1) / 8); + + const ListList = utils.List(utils.List(u8, 4), 2); + try expect(try ListList.maxInLength() == 2 * 4 + 2 * (4 * 1)); + + const S = struct { + a: u32, + b: [2]u8, + }; + try expect(try libssz.maxInLength(S) == 4 + 2); +} + +test "minInLength for fixed and variable types" { + try expect(try libssz.minInLength(u8) == 1); + try expect(try libssz.minInLength(u64) == 8); + try expect(try libssz.minInLength(bool) == 1); + try expect(try libssz.minInLength([4]u8) == 4); + try expect(try libssz.minInLength([10]bool) == (10 + 7) / 8); + + const ListU64 = utils.List(u64, 16); + try expect(ListU64.minInLength() == 0); + + const Bitlist32 = utils.Bitlist(32); + try expect(Bitlist32.minInLength() == 1); + + const S = struct { + a: u32, + b: [2]u8, + }; + try expect(try libssz.minInLength(S) == 4 + 2); + + const VarS = struct { + a: u32, + b: []const u8, + }; + _ = libssz.minInLength(VarS) catch |e| try expect(e == error.NoMinInLengthAvailable); +} + +test "deserialize rejects payload shorter than minInLength" { + var out_u32: u32 = undefined; + try expectError(error.PayloadTooSmall, deserialize(u32, &[_]u8{ 0x01, 0x02 }, &out_u32, null)); + + var out_bool: bool = undefined; + try expectError(error.PayloadTooSmall, deserialize(bool, &[_]u8{}, &out_bool, null)); + + var out_fixed: [4]u8 = undefined; + try expectError(error.PayloadTooSmall, deserialize([4]u8, &[_]u8{ 0x01, 0x02 }, &out_fixed, null)); +} + +test "minInLength/maxInLength for struct with List field" { + const S = struct { + id: u32, + data: utils.List(u8, 8), + }; + // min: 4 (u32) + 4 (offset for variable field) + 0 (empty list) = 8 + try expect(try libssz.minInLength(S) == 4 + 4 + 0); + // max: 4 (u32) + 4 (offset for variable field) + 8*1 (full list) = 16 + try expect(try libssz.maxInLength(S) == 4 + 4 + 8 * 1); +} + +test "deserialize rejects payload longer than maxInLength" { + var out_u32: u32 = undefined; + try expectError(error.PayloadTooLarge, deserialize(u32, &[_]u8{ 0x01, 0x02, 0x03, 0x04, 0x05 }, &out_u32, null)); + + var out_bool: bool = undefined; + try expectError(error.PayloadTooLarge, deserialize(bool, &[_]u8{ 0x00, 0x01 }, &out_bool, null)); + + var out_fixed: [2]u8 = undefined; + try expectError(error.PayloadTooLarge, deserialize([2]u8, &[_]u8{ 0x01, 0x02, 0x03 }, &out_fixed, null)); +} + test "zeam stf input" { const Bytes32 = [32]u8; const Bytes48 = [48]u8; @@ -1682,7 +1764,7 @@ test "List validation - size limits enforced" { 0x05, 0x00, 0x00, 0x00, // u32 = 5 }; - try std.testing.expectError(error.ListTooBig, deserialize(utils.List(u32, 3), &oversized_data, &list, std.testing.allocator)); + try std.testing.expectError(error.PayloadTooLarge, deserialize(utils.List(u32, 3), &oversized_data, &list, std.testing.allocator)); } } diff --git a/src/utils.zig b/src/utils.zig index 3c319df..53fa4ee 100644 --- a/src/utils.zig +++ b/src/utils.zig @@ -39,6 +39,19 @@ pub fn List(T: type, comptime N: usize) type { return false; } + /// Maximum serialized byte length for List(T, N) with at most N elements. + pub fn maxInLength() !usize { + if (try lib.isFixedSizeObject(Item)) { + return N * try lib.serializedFixedSize(Item); + } + return N * @sizeOf(u32) + N * try lib.maxInLength(Item); + } + + /// Minimum serialized byte length for List(T, N) (empty list). + pub fn minInLength() usize { + return 0; + } + pub fn sszDecode(serialized: []const u8, out: *Self, allocator: ?Allocator) !void { // BitList[N] or regular List[N]? const alloc = allocator orelse return error.AllocatorRequired; @@ -60,11 +73,6 @@ pub fn List(T: type, comptime N: usize) type { const pitch = try lib.serializedFixedSize(Self.Item); const n_items = serialized.len / pitch; - // Validate list size against maximum N - if (n_items > N) { - return error.ListTooBig; - } - for (0..n_items) |i| { var item: Self.Item = undefined; try deserialize(Self.Item, serialized[i * pitch .. (i + 1) * pitch], &item, allocator); @@ -283,6 +291,16 @@ pub fn Bitlist(comptime N: usize) type { return false; } + /// Maximum serialized byte length for Bitlist(N) (N bits + sentinel). + pub fn maxInLength() usize { + return (N + 7 + 1) / 8; + } + + /// Minimum serialized byte length for Bitlist(N) (empty bitlist: one byte with sentinel). + pub fn minInLength() usize { + return 1; + } + pub fn init(allocator: Allocator) !Self { return .{ .inner = .empty, .allocator = allocator, .length = 0 }; }