Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/beacon_tests.zig
Original file line number Diff line number Diff line change
Expand Up @@ -779,7 +779,7 @@ test "hashTreeRoot for pointer types" {
try hashTreeRoot(Sha256, *u32, &value, &hash, std.testing.allocator);

var deserialized: u32 = undefined;
try deserialize(u32, &hash, &deserialized, std.testing.allocator);
try deserialize(u32, hash[0..4], &deserialized, std.testing.allocator);
try expect(deserialized == value);
}

Expand All @@ -790,7 +790,7 @@ test "hashTreeRoot for pointer types" {
try hashTreeRoot(Sha256, *[4]u8, values_ptr, &hash, std.testing.allocator);

var deserialized: [4]u8 = undefined;
try deserialize([4]u8, &hash, &deserialized, std.testing.allocator);
try deserialize([4]u8, hash[0..4], &deserialized, std.testing.allocator);
try expect(std.mem.eql(u8, &deserialized, values_ptr));
}

Expand Down
128 changes: 123 additions & 5 deletions src/lib.zig
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,107 @@ pub fn isFixedSizeObject(T: type) !bool {
return true;
}

/// Returns the maximum possible serialized byte length for type `T`.
/// Useful for pre-allocating buffers or validating input bounds.
/// For variable-length types (e.g. slice), use a type that encodes max length (e.g. List(T, N)) or returns error.
pub fn maxInLength(T: type) !usize {
if (comptime std.meta.hasFn(T, "maxInLength")) {
return T.maxInLength();
}

const info = @typeInfo(T);
return switch (info) {
.int => @sizeOf(T),
.bool => @as(usize, 1),
.null => @as(usize, 0),
.array => |array| if (array.child == bool)
(array.len + 7) / 8
else blk: {
const child_max = try maxInLength(array.child);
if (try isFixedSizeObject(array.child)) {
break :blk array.len * child_max;
} else {
break :blk array.len * child_max + 4 * array.len;
}
},
.optional => 1 + try maxInLength(info.optional.child),
.pointer => |ptr| switch (ptr.size) {
.slice => error.NoMaxInLengthAvailable,
.one => maxInLength(ptr.child),
else => error.NoMaxInLengthAvailable,
},
.@"struct" => |str| blk: {
var total: usize = 0;
inline for (str.fields) |field| {
if (try isFixedSizeObject(field.type)) {
total += try maxInLength(field.type);
} else {
total += 4 + try maxInLength(field.type);
}
}
break :blk total;
},
.@"union" => |u| blk: {
if (u.tag_type == null) return error.UnionIsNotTagged;
var m: usize = 0;
inline for (u.fields) |f| {
const n = try maxInLength(f.type);
if (n > m) m = n;
}
break :blk 1 + m;
},
else => error.NoMaxInLengthAvailable,
};
}

/// Returns the minimum possible serialized byte length for type `T`.
/// Used together with maxInLength to validate input bounds before deserializing.
pub fn minInLength(T: type) !usize {
if (comptime std.meta.hasFn(T, "minInLength")) {
return T.minInLength();
}

const info = @typeInfo(T);
return switch (info) {
.int => @sizeOf(T),
.bool => @as(usize, 1),
.null => @as(usize, 0),
.array => |array| if (array.child == bool)
(array.len + 7) / 8
else if (try isFixedSizeObject(array.child))
array.len * try minInLength(array.child)
else
array.len * @sizeOf(u32) + array.len * try minInLength(array.child),
.optional => 1,
.pointer => |ptr| switch (ptr.size) {
.slice => error.NoMinInLengthAvailable,
.one => minInLength(ptr.child),
else => error.NoMinInLengthAvailable,
},
.@"struct" => |str| blk: {
var total: usize = 0;
inline for (str.fields) |field| {
if (try isFixedSizeObject(field.type)) {
total += try minInLength(field.type);
} else {
total += 4 + try minInLength(field.type);
}
}
break :blk total;
},
.@"union" => |u| blk: {
if (u.tag_type == null) return error.UnionIsNotTagged;
var m: usize = std.math.maxInt(usize);
inline for (u.fields) |f| {
const n = try minInLength(f.type);
if (n < m) m = n;
}
break :blk 1 + m;
},
else => error.NoMinInLengthAvailable,
};
}

/// Provides the generic serialization of any `data` var to SSZ. The
/// serialization is written to the `ArrayList` `l`.
pub fn serialize(T: type, data: T, l: *ArrayList(u8), allocator: Allocator) !void {
Expand Down Expand Up @@ -312,12 +413,29 @@ pub fn serialize(T: type, data: T, l: *ArrayList(u8), allocator: Allocator) !voi
}
}

/// Takes a byte array containing the serialized payload of type `T` (with
/// possible trailing data) and deserializes it into the `T` object pointed
/// at by `out`.
/// Takes a byte array containing the serialized payload of type `T` and
/// deserializes it into the `T` object pointed at by `out`.
/// The payload must be within [minInLength, maxInLength] bounds for `T`.
pub fn deserialize(T: type, serialized: []const u8, out: *T, allocator: ?Allocator) !void {
const has_custom_decode = comptime std.meta.hasFn(T, "sszDecode");
const enforce_min = !has_custom_decode or comptime std.meta.hasFn(T, "minInLength");
const enforce_max = !has_custom_decode or comptime std.meta.hasFn(T, "maxInLength");

// Bounds check: ensure serialized length is within [minInLength, maxInLength]
const min_len: ?usize = if (enforce_min) blk: {
const m = minInLength(T) catch break :blk null;
break :blk m;
} else null;
if (min_len) |m| if (serialized.len < m) return error.PayloadTooSmall;

const max_len: ?usize = if (enforce_max) blk: {
const m = maxInLength(T) catch break :blk null;
break :blk m;
} else null;
if (max_len) |m| if (serialized.len > m) return error.PayloadTooLarge;

// shortcut if the type implements its own decode method
if (comptime std.meta.hasFn(T, "sszDecode")) {
if (has_custom_decode) {
return T.sszDecode(serialized, out, allocator);
}

Expand Down Expand Up @@ -497,7 +615,7 @@ pub fn deserialize(T: type, serialized: []const u8, out: *T, allocator: ?Allocat
.@"union" => {
// Read the type index
var union_index: u8 = undefined;
try deserialize(u8, serialized, &union_index, allocator);
try deserialize(u8, serialized[0..1], &union_index, allocator);

// Use the index to figure out which type must
// be deserialized.
Expand Down
84 changes: 83 additions & 1 deletion src/tests.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1121,6 +1121,88 @@ test "isFixedSizeObject correctly identifies List/Bitlist as variable-size" {
try expect(!try isFixedSizeObject(StructWithList));
}

test "maxInLength for fixed and variable types" {
try expect(try libssz.maxInLength(u8) == 1);
try expect(try libssz.maxInLength(u64) == 8);
try expect(try libssz.maxInLength(bool) == 1);
try expect(try libssz.maxInLength([4]u8) == 4);
try expect(try libssz.maxInLength([10]bool) == (10 + 7) / 8);

const ListU64 = utils.List(u64, 16);
try expect(try ListU64.maxInLength() == 16 * 8);

const Bitlist32 = utils.Bitlist(32);
try expect(Bitlist32.maxInLength() == (32 + 7 + 1) / 8);

const ListList = utils.List(utils.List(u8, 4), 2);
try expect(try ListList.maxInLength() == 2 * 4 + 2 * (4 * 1));

const S = struct {
a: u32,
b: [2]u8,
};
try expect(try libssz.maxInLength(S) == 4 + 2);
}

test "minInLength for fixed and variable types" {
try expect(try libssz.minInLength(u8) == 1);
try expect(try libssz.minInLength(u64) == 8);
try expect(try libssz.minInLength(bool) == 1);
try expect(try libssz.minInLength([4]u8) == 4);
try expect(try libssz.minInLength([10]bool) == (10 + 7) / 8);

const ListU64 = utils.List(u64, 16);
try expect(ListU64.minInLength() == 0);

const Bitlist32 = utils.Bitlist(32);
try expect(Bitlist32.minInLength() == 1);

const S = struct {
a: u32,
b: [2]u8,
};
try expect(try libssz.minInLength(S) == 4 + 2);

const VarS = struct {
a: u32,
b: []const u8,
};
_ = libssz.minInLength(VarS) catch |e| try expect(e == error.NoMinInLengthAvailable);
}

test "deserialize rejects payload shorter than minInLength" {
var out_u32: u32 = undefined;
try expectError(error.PayloadTooSmall, deserialize(u32, &[_]u8{ 0x01, 0x02 }, &out_u32, null));

var out_bool: bool = undefined;
try expectError(error.PayloadTooSmall, deserialize(bool, &[_]u8{}, &out_bool, null));

var out_fixed: [4]u8 = undefined;
try expectError(error.PayloadTooSmall, deserialize([4]u8, &[_]u8{ 0x01, 0x02 }, &out_fixed, null));
}

test "minInLength/maxInLength for struct with List field" {
const S = struct {
id: u32,
data: utils.List(u8, 8),
};
// min: 4 (u32) + 4 (offset for variable field) + 0 (empty list) = 8
try expect(try libssz.minInLength(S) == 4 + 4 + 0);
// max: 4 (u32) + 4 (offset for variable field) + 8*1 (full list) = 16
try expect(try libssz.maxInLength(S) == 4 + 4 + 8 * 1);
}

test "deserialize rejects payload longer than maxInLength" {
var out_u32: u32 = undefined;
try expectError(error.PayloadTooLarge, deserialize(u32, &[_]u8{ 0x01, 0x02, 0x03, 0x04, 0x05 }, &out_u32, null));

var out_bool: bool = undefined;
try expectError(error.PayloadTooLarge, deserialize(bool, &[_]u8{ 0x00, 0x01 }, &out_bool, null));

var out_fixed: [2]u8 = undefined;
try expectError(error.PayloadTooLarge, deserialize([2]u8, &[_]u8{ 0x01, 0x02, 0x03 }, &out_fixed, null));
}

test "zeam stf input" {
const Bytes32 = [32]u8;
const Bytes48 = [48]u8;
Expand Down Expand Up @@ -1682,7 +1764,7 @@ test "List validation - size limits enforced" {
0x05, 0x00, 0x00, 0x00, // u32 = 5
};

try std.testing.expectError(error.ListTooBig, deserialize(utils.List(u32, 3), &oversized_data, &list, std.testing.allocator));
try std.testing.expectError(error.PayloadTooLarge, deserialize(utils.List(u32, 3), &oversized_data, &list, std.testing.allocator));
}
}

Expand Down
28 changes: 23 additions & 5 deletions src/utils.zig
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ pub fn List(T: type, comptime N: usize) type {
return false;
}

/// Maximum serialized byte length for List(T, N) with at most N elements.
pub fn maxInLength() !usize {
if (try lib.isFixedSizeObject(Item)) {
return N * try lib.serializedFixedSize(Item);
}
return N * @sizeOf(u32) + N * try lib.maxInLength(Item);
}

/// Minimum serialized byte length for List(T, N) (empty list).
pub fn minInLength() usize {
return 0;
}

pub fn sszDecode(serialized: []const u8, out: *Self, allocator: ?Allocator) !void {
// BitList[N] or regular List[N]?
const alloc = allocator orelse return error.AllocatorRequired;
Expand All @@ -60,11 +73,6 @@ pub fn List(T: type, comptime N: usize) type {
const pitch = try lib.serializedFixedSize(Self.Item);
const n_items = serialized.len / pitch;

// Validate list size against maximum N
if (n_items > N) {
return error.ListTooBig;
}

for (0..n_items) |i| {
var item: Self.Item = undefined;
try deserialize(Self.Item, serialized[i * pitch .. (i + 1) * pitch], &item, allocator);
Expand Down Expand Up @@ -283,6 +291,16 @@ pub fn Bitlist(comptime N: usize) type {
return false;
}

/// Maximum serialized byte length for Bitlist(N) (N bits + sentinel).
pub fn maxInLength() usize {
return (N + 7 + 1) / 8;
}

/// Minimum serialized byte length for Bitlist(N) (empty bitlist: one byte with sentinel).
pub fn minInLength() usize {
return 1;
}

pub fn init(allocator: Allocator) !Self {
return .{ .inner = .empty, .allocator = allocator, .length = 0 };
}
Expand Down