Skip to content

Commit

Permalink
Refine the borsh implementation (#45)
Browse files Browse the repository at this point in the history
* Fix the build

* Bump dependencies

* Rewrite borsh tests to catch more errors

In particular, test at byte boundaries and try some unusual (large) uints

* Replace vec! from deserialization with static array

* Simplify borsh serialize and deserialize

* Simplify BorshSerialize trait boundsy
  • Loading branch information
danlehmann authored Jul 24, 2024
1 parent c99730e commit 0a54994
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 70 deletions.
8 changes: 4 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ borsh = ["dep:borsh"]
schemars = ["dep:schemars", "std"]

[dependencies]
num-traits = { version = "0.2.17", default-features = false, optional = true }
defmt = { version = "0.3.5", optional = true }
serde = { version = "1.0", optional = true, default-features = false}
num-traits = { version = "0.2.19", default-features = false, optional = true }
defmt = { version = "0.3.8", optional = true }
serde = { version = "1.0", optional = true, default-features = false }
borsh = { version = "1.5.1", optional = true, features = ["unstable__schema"], default-features = false }
schemars = { version = "0.8.1", optional = true, features = ["derive"], default-features = false }
schemars = { version = "0.8.21", optional = true, features = ["derive"], default-features = false }

[dev-dependencies]
serde_test = "1.0"
58 changes: 25 additions & 33 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ use core::ops::{
#[cfg(feature = "serde")]
use serde::{Deserialize, Deserializer, Serialize, Serializer};

#[cfg(feature = "borsh")]
use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};

#[cfg(all(feature = "borsh", not(feature = "std")))]
use alloc::{collections::BTreeMap, string::ToString};

Expand Down Expand Up @@ -1069,51 +1066,46 @@ where
}
}

// Borsh is byte-size little-endian de-needs-external-schema no-bit-compression serde.
// Current ser/de for it is not optimal impl because const math is not stable nor primitives has bits traits.
// Uses minimal amount of bytes to fit needed amount of bits without compression (borsh does not have it anyway).
#[cfg(feature = "borsh")]
impl<T, const BITS: usize> BorshSerialize for UInt<T, BITS>
impl<T, const BITS: usize> borsh::BorshSerialize for UInt<T, BITS>
where
Self: Number,
T: BorshSerialize
+ From<u8>
+ BitAnd<T, Output = T>
+ TryInto<u8>
+ Copy
+ Shr<usize, Output = T>,
<UInt<T, BITS> as Number>::UnderlyingType:
Shr<usize, Output = T> + TryInto<u8> + From<u8> + BitAnd<T>,
T: borsh::BorshSerialize,
{
fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
let value = self.value();
let length = (BITS + 7) / 8;
let mut bytes = 0;
let mask: T = u8::MAX.into();
while bytes < length {
let le_byte: u8 = ((value >> (bytes << 3)) & mask)
.try_into()
.ok()
.expect("we cut to u8 via mask");
writer.write(&[le_byte])?;
bytes += 1;
}
let serialized_byte_count = (BITS + 7) / 8;
let mut buffer = [0u8; 16];
self.value.serialize(&mut &mut buffer[..])?;
writer.write(&buffer[0..serialized_byte_count])?;

Ok(())
}
}

#[cfg(feature = "borsh")]
impl<
T: BorshDeserialize + core::cmp::PartialOrd<<UInt<T, BITS> as Number>::UnderlyingType>,
T: borsh::BorshDeserialize + PartialOrd<<UInt<T, BITS> as Number>::UnderlyingType>,
const BITS: usize,
> BorshDeserialize for UInt<T, BITS>
> borsh::BorshDeserialize for UInt<T, BITS>
where
Self: Number,
{
fn deserialize_reader<R: borsh::io::Read>(reader: &mut R) -> borsh::io::Result<Self> {
let mut buf = vec![0u8; core::mem::size_of::<T>()];
reader.read(&mut buf)?;
let value = T::deserialize(&mut &buf[..])?;
// Ideally, we'd want a buffer of size `BITS >> 3` or `size_of::<T>`, but that's not possible
// with arrays at present (feature(generic_const_exprs), once stable, will allow this).
// vec! would be an option, but an allocation is not expected at this level.
// Therefore, allocate a 16 byte buffer and take a slice out of it.
let serialized_byte_count = (BITS + 7) / 8;
let underlying_byte_count = core::mem::size_of::<T>();
let mut buf = [0u8; 16];

// Read from the source, advancing cursor by the exact right number of bytes
reader.read(&mut buf[..serialized_byte_count])?;

// Deserialize the underlying type. We have to pass in the correct number of bytes of the
// underlying type (or more, but let's be precise). The unused bytes are all still zero
let value = T::deserialize(&mut &buf[..underlying_byte_count])?;

if value >= Self::MIN.value() && value <= Self::MAX.value() {
Ok(Self { value })
} else {
Expand All @@ -1126,7 +1118,7 @@ where
}

#[cfg(feature = "borsh")]
impl<T, const BITS: usize> BorshSchema for UInt<T, BITS> {
impl<T, const BITS: usize> borsh::BorshSchema for UInt<T, BITS> {
fn add_definitions_recursively(
definitions: &mut BTreeMap<borsh::schema::Declaration, borsh::schema::Definition>,
) {
Expand Down
138 changes: 105 additions & 33 deletions tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1912,42 +1912,114 @@ fn serde() {
);
}

#[cfg(all(feature = "borsh", feature = "std"))]
#[test]
fn borsh() {
#[cfg(feature = "borsh")]
mod borsh_tests {
use arbitrary_int::{u1, u14, u15, u6, u63, u65, u7, u72, u79, u80, u81, u9, Number, UInt};
use borsh::schema::BorshSchemaContainer;
use borsh::{BorshDeserialize, BorshSerialize};
let mut buf = Vec::new();
let base_input: u8 = 42;
let input = u9::new(base_input.into());
input.serialize(&mut buf).unwrap();
let output = u9::deserialize(&mut buf.as_ref()).unwrap();
let fits = u16::new(base_input.into());
assert_eq!(buf, fits.to_le_bytes());
assert_eq!(input, output);

let input = u63::MAX;
let fits = u64::new(input.value());
let mut buf = Vec::new();
input.serialize(&mut buf).unwrap();
let output: u63 = u63::deserialize(&mut buf.as_ref()).unwrap();
assert_eq!(buf, fits.to_le_bytes());
assert_eq!(input, output);

let schema = BorshSchemaContainer::for_type::<u9>();
match schema.get_definition("u9").expect("exists") {
borsh::schema::Definition::Primitive(2) => {}
_ => panic!("unexpected schema"),
use borsh::{BorshDeserialize, BorshSchema, BorshSerialize};
use std::fmt::Debug;

fn test_roundtrip<T: Number + BorshSerialize + BorshDeserialize + PartialEq + Eq + Debug>(
input: T,
expected_buffer: &[u8],
) {
let mut buf = Vec::new();

// Serialize and compare against expected
input.serialize(&mut buf).unwrap();
assert_eq!(buf, expected_buffer);

// Add to the buffer a second time - this is a better test for the deserialization
// as it ensures we request the correct number of bytes
input.serialize(&mut buf).unwrap();

// Deserialize back and compare against input
let output = T::deserialize(&mut buf.as_ref()).unwrap();
let output2 = T::deserialize(&mut &buf[buf.len() / 2..]).unwrap();
assert_eq!(input, output);
assert_eq!(input, output2);
}

#[test]
fn test_serialize_deserialize() {
// Run against plain u64 first (not an arbitrary_int)
test_roundtrip(
0x12345678_9ABCDEF0u64,
&[0xF0, 0xDE, 0xBC, 0x9A, 0x78, 0x56, 0x34, 0x12],
);

// Now try various arbitrary ints
test_roundtrip(u1::new(0b0), &[0]);
test_roundtrip(u1::new(0b1), &[1]);
test_roundtrip(u6::new(0b101101), &[0b101101]);
test_roundtrip(u14::new(0b110101_11001101), &[0b11001101, 0b110101]);
test_roundtrip(
u72::new(0x36_01234567_89ABCDEF),
&[0xEF, 0xCD, 0xAB, 0x89, 0x67, 0x45, 0x23, 0x01, 0x36],
);

// Pick a byte boundary (80; test one below and one above to ensure we get the right number
// of bytes)
test_roundtrip(
u79::MAX,
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F],
);
test_roundtrip(
u80::MAX,
&[0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF],
);
test_roundtrip(
u81::MAX,
&[
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01,
],
);

// Test actual u128 and arbitrary u128 (which is a legal one, though not a predefined)
test_roundtrip(
u128::MAX,
&[
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF,
],
);
test_roundtrip(
UInt::<u128, 128>::MAX,
&[
0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF,
0xFF, 0xFF,
],
);
}

fn verify_byte_count_in_schema<T: BorshSchema + ?Sized>(expected_byte_count: u8, name: &str) {
let schema = BorshSchemaContainer::for_type::<T>();
match schema.get_definition(name).expect("exists") {
borsh::schema::Definition::Primitive(byte_count) => {
assert_eq!(*byte_count, expected_byte_count);
}
_ => panic!("unexpected schema"),
}
}

let input = u50::MAX;
let fits = u64::new(input.value());
let mut buf = Vec::new();
input.serialize(&mut buf).unwrap();
assert!(buf.len() < fits.to_le_bytes().len());
assert_eq!(buf, fits.to_le_bytes()[0..((u50::BITS + 7) / 8)]);
let output: u50 = u50::deserialize(&mut buf.as_ref()).unwrap();
assert_eq!(input, output);
#[test]
fn test_schema_byte_count() {
verify_byte_count_in_schema::<u1>(1, "u1");

verify_byte_count_in_schema::<u7>(1, "u7");

verify_byte_count_in_schema::<UInt<u8, 8>>(1, "u8");
verify_byte_count_in_schema::<UInt<u32, 8>>(1, "u8");

verify_byte_count_in_schema::<u9>(2, "u9");

verify_byte_count_in_schema::<u15>(2, "u15");
verify_byte_count_in_schema::<UInt<u128, 15>>(2, "u15");

verify_byte_count_in_schema::<u63>(8, "u63");

verify_byte_count_in_schema::<u65>(9, "u65");
}
}

#[cfg(feature = "schemars")]
Expand Down

0 comments on commit 0a54994

Please sign in to comment.