diff --git a/Cargo.toml b/Cargo.toml index 64af089..694fde2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ serde = ["encode", "dep:serde"] [dependencies] fxhash = "0.2" bytes = { version = "1", optional = true } -serde = { version = "1", optional = true } +serde = { version = "1", optional = true, features = ["derive"] } [target.'cfg(any(target_arch = "x86", target_arch = "x86_64"))'.dependencies] varint-simd = { version = "0.4", optional = true } @@ -55,11 +55,13 @@ varint-simd = { version = "0.4", optional = true } unsigned-varint = { version = "0.8", optional = true } [dev-dependencies] +bincode = { version = "2", features = ["serde"] } criterion = "0.7" pando = { path = ".", features = ["encode", "serde"] } pando-macros = { path = "macros" } rand = "0.9" rand_chacha = "0.9" +serde_json = "1" [[bench]] name = "apply_local" @@ -74,5 +76,10 @@ name = "encode" harness = false required-features = ["encode"] +[[bench]] +name = "serde" +harness = false +required-features = ["serde"] + [lints] workspace = true diff --git a/benches/serde.rs b/benches/serde.rs new file mode 100644 index 0000000..35ac5dd --- /dev/null +++ b/benches/serde.rs @@ -0,0 +1,185 @@ +#![allow(missing_docs)] +#![allow(clippy::unwrap_used)] + +mod common; + +use core::fmt; + +use common::TreeSize; +use criterion::{BenchmarkId, Criterion}; +use rand::SeedableRng; +use serde::de; + +fn bench_serialize(c: &mut Criterion) { + let mut rng = ::seed_from_u64(42); + + let mut group = + c.benchmark_group(format!("serialize_{}_{}", D::NAME, M::NAME)); + + for test_vector in TestVector::iter() { + let tree = test_vector.tree_size.to_tree::(&mut rng); + group.bench_function(test_vector.bench_id(), |bencher| { + bencher.iter(|| { + D::serialize( + &tree + .serialize() + .with_tree_state(test_vector.encode_tree_state), + ); + }) + }); + } +} + +fn bench_deserialize(c: &mut Criterion) { + let mut rng = ::seed_from_u64(42); + + let mut group = + c.benchmark_group(format!("deserialize_{}_{}", D::NAME, M::NAME)); + + for test_vector in TestVector::iter() { + let tree = test_vector.tree_size.to_tree::(&mut rng); + let serialized = D::serialize( + &tree.serialize().with_tree_state(test_vector.encode_tree_state), + ); + group.bench_function(test_vector.bench_id(), |bencher| { + bencher.iter(|| { + D::deserialize_seed( + pando::Tree::::deserialize(2), + &serialized, + ); + }) + }); + } +} + +trait DataFormat { + const NAME: &str; + + type Serialized; + + fn serialize(value: &impl serde::Serialize) -> Self::Serialized; + + fn deserialize_seed<'de, T: de::DeserializeSeed<'de>>( + seed: T, + serialized: &'de Self::Serialized, + ) -> T::Value; +} + +trait Metadata: + Default + + Clone + + fmt::Debug + + Eq + + serde::Serialize + + serde::de::DeserializeOwned +{ + const NAME: &str; +} + +struct Bincode; + +#[derive( + Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize, +)] +#[serde(transparent)] +struct LargeString { + inner: String, +} + +struct TestVector { + encode_tree_state: bool, + tree_size: TreeSize, +} + +impl TestVector { + fn iter() -> impl Iterator { + let mut items = Vec::new(); + + for tree_size in TreeSize::iter() { + for encode_tree_state in [true, false] { + items.push(Self { encode_tree_state, tree_size }); + } + } + + items.into_iter() + } + + fn bench_id(&self) -> BenchmarkId { + BenchmarkId::new( + self.tree_size, + if self.encode_tree_state { + "with_state" + } else { + "without_state" + }, + ) + } +} + +impl Metadata for () { + const NAME: &str = "unit"; +} + +impl Metadata for LargeString { + const NAME: &str = "large_string"; +} + +impl Default for LargeString { + fn default() -> Self { + Self { inner: "x".repeat(SIZE) } + } +} + +impl DataFormat for Bincode { + const NAME: &str = "bincode"; + + type Serialized = Vec; + + fn serialize(value: &impl serde::Serialize) -> Self::Serialized { + bincode::serde::encode_to_vec(value, bincode::config::standard()) + .expect("failed to serialize with bincode") + } + + fn deserialize_seed<'de, T: de::DeserializeSeed<'de>>( + seed: T, + serialized: &'de Self::Serialized, + ) -> T::Value { + let (deserialized, _num_read) = + bincode::serde::seed_decode_from_slice( + seed, + serialized, + bincode::config::standard(), + ) + .expect("failed to deserialize with bincode"); + + deserialized + } +} + +fn serialize_unit_bincode(c: &mut Criterion) { + bench_serialize::(c); +} + +fn serialize_large_string_bincode(c: &mut Criterion) { + bench_serialize::(c); +} + +fn decode_unit_bincode(c: &mut Criterion) { + bench_deserialize::(c); +} + +fn decode_large_string_bincode(c: &mut Criterion) { + bench_deserialize::(c); +} + +criterion::criterion_group!( + encode, + serialize_unit_bincode, + serialize_large_string_bincode +); +criterion::criterion_group!( + decode, + decode_unit_bincode, + decode_large_string_bincode +); +criterion::criterion_main!(encode, decode); diff --git a/src/forest.rs b/src/forest.rs index ae98112..4647c4e 100644 --- a/src/forest.rs +++ b/src/forest.rs @@ -6,9 +6,10 @@ use crate::node_id::{GlobalNodeId, LocalNodeId}; use crate::op_log; #[derive(Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub(crate) struct Forest { /// TODO: docs. - forest: Vec, + nodes: Vec, /// Keeps track of the nodes that cannot be added to the visible tree /// because we haven't yet received the creation of their parent. @@ -20,6 +21,7 @@ pub(crate) struct Forest { } #[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub(crate) struct Node { /// TODO: docs. children: Vec, @@ -32,6 +34,8 @@ pub(crate) struct Node { } #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] pub(crate) struct NodeKey(usize); #[derive(Debug, PartialEq, Eq)] @@ -42,6 +46,7 @@ pub(crate) enum NodeState { } #[derive(Clone, Default)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] struct Backlog { /// Map from the [`GlobalNodeId`] of a node whose creation we haven't yet /// received to the [`NodeKey`]s of all the nodes in the forest that would @@ -158,7 +163,7 @@ impl Forest { let mut id_to_key = FxHashMap::default(); id_to_key.insert(root_id, Self::VISIBLE_ROOT_KEY); - Self { forest: vec![root], backlog: Backlog::default(), id_to_key } + Self { nodes: vec![root], backlog: Backlog::default(), id_to_key } } /// TODO: docs. @@ -241,10 +246,10 @@ impl Forest { ) -> NodeKey { debug_assert!(!self.id_to_key.contains_key(&node_id)); - let node_key = NodeKey::new(self.forest.len()); + let node_key = NodeKey::new(self.nodes.len()); // Insert the new node into the forest. - self.forest.push(Node { + self.nodes.push(Node { children: Vec::new(), creation_key: node_creation_key, parent_key, @@ -433,7 +438,7 @@ impl ops::Index for Forest { #[track_caller] #[inline] fn index(&self, key: NodeKey) -> &Self::Output { - &self.forest[key.into_usize()] + &self.nodes[key.into_usize()] } } @@ -441,7 +446,7 @@ impl ops::IndexMut for Forest { #[track_caller] #[inline] fn index_mut(&mut self, key: NodeKey) -> &mut Self::Output { - &mut self.forest[key.into_usize()] + &mut self.nodes[key.into_usize()] } } @@ -502,6 +507,10 @@ pub(crate) mod udr { /// TODO: docs. #[derive(Debug, Copy, Clone, PartialEq, Eq)] + #[cfg_attr( + feature = "serde", + derive(serde::Serialize, serde::Deserialize) + )] pub(crate) enum PrevParent { /// TODO: docs. Backlogged, @@ -741,7 +750,7 @@ mod invariants { #[allow(dead_code)] #[track_caller] pub(crate) fn assert_invariants(&self) { - let Some(visible_root) = self.forest.first() else { + let Some(visible_root) = self.nodes.first() else { assert!(self.backlog.is_empty()); assert!(self.id_to_key.is_empty()); return; @@ -756,9 +765,9 @@ mod invariants { .copied() .expect("tree not empty"); - assert_eq!(max_node_key.into_usize() + 1, self.forest.len()); + assert_eq!(max_node_key.into_usize() + 1, self.nodes.len()); - for node_key in (0..self.forest.len()).map(NodeKey::new) { + for node_key in (0..self.nodes.len()).map(NodeKey::new) { self.assert_node_invariants(node_key); } @@ -782,7 +791,7 @@ mod invariants { // If the node is not the root of its subtree, make sure its parent // is valid. if node.parent_key.is_some() { - assert!(node.parent_key.into_usize() < self.forest.len()); + assert!(node.parent_key.into_usize() < self.nodes.len()); } // Make sure the node's children are valid. @@ -794,7 +803,7 @@ mod invariants { ); assert!( - child_key.into_usize() < self.forest.len(), + child_key.into_usize() < self.nodes.len(), "the node at {node_key:?} has a child with an oob key: \ {child_key:?}" ); @@ -913,7 +922,7 @@ mod debug { } fn iter(&self) -> impl Iterator { - self.forest + self.nodes .iter() .enumerate() .map(|(idx, node)| (node, NodeKey::new(idx))) @@ -1010,7 +1019,7 @@ mod encode { impl Encode for Forest { #[inline] fn encode(&self, buf: &mut impl Buffer) { - self.forest.encode(buf); + self.nodes.encode(buf); self.backlog.encode(buf); self.id_to_key.encode(buf); } @@ -1024,7 +1033,7 @@ mod encode { let forest = Vec::decode(buf)?; let backlog = Backlog::decode(buf)?; let id_to_key = FxHashMap::decode(buf)?; - Ok(Self { forest, backlog, id_to_key }) + Ok(Self { nodes: forest, backlog, id_to_key }) } } diff --git a/src/gtree.rs b/src/gtree.rs index 2599c1a..0e53428 100644 --- a/src/gtree.rs +++ b/src/gtree.rs @@ -10,6 +10,15 @@ pub(crate) trait Summarize { } #[derive(Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr( + feature = "serde", + serde(bound( + serialize = "T: serde::Serialize, T::Summary: serde::Serialize", + deserialize = "T: serde::Deserialize<'de>, T::Summary: \ + serde::Deserialize<'de>", + )) +)] pub(crate) struct Gtree { /// The internal nodes of the tree. inodes: Vec>, @@ -26,9 +35,11 @@ pub(crate) struct Gtree { } #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] pub(crate) struct Key(LeafIdx); -#[derive(Clone)] +#[derive(Debug, Clone)] struct Inode { /// The indices of the children of this inode in the [`Gtree`]. The first /// [`num_children`](Self::num_children) items are valid, the rest are @@ -53,7 +64,8 @@ struct Inode { parent_idx: InodeIdx, } -#[derive(Clone)] +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] struct Leaf { /// The item stored in this leaf. item: T, @@ -63,14 +75,20 @@ struct Leaf { } #[derive(Copy, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] #[repr(transparent)] struct InodeIdx(NodeIdx); #[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] #[repr(transparent)] struct LeafIdx(NodeIdx); #[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] struct NodeIdx(usize); enum Children<'inode> { @@ -276,6 +294,20 @@ where } } +impl PartialEq for Gtree +where + T: PartialEq, +{ + #[inline] + fn eq(&self, other: &Self) -> bool { + self.len() == other.len() + && self + .iter_forward() + .zip(other.iter_forward()) + .all(|(lhs, rhs)| *lhs == *rhs) + } +} + impl ops::Index for Gtree { type Output = T; @@ -1654,6 +1686,266 @@ mod gtree_encode { } } +#[cfg(feature = "serde")] +mod serde_impls { + use core::marker::PhantomData; + + use serde::de; + use serde::ser::{Serialize, SerializeStruct}; + + use super::*; + + const INODE_NAME: &str = "Inode"; + + #[derive(serde::Deserialize)] + #[serde(field_identifier, rename_all = "snake_case")] + enum InodeField { + Children, + ChildrenAreLeaves, + MaxSummary, + ParentIdx, + } + + struct InodeVisitor(PhantomData); + + impl InodeField { + const AS_SLICE: &'static [&'static str] = &[ + Self::Children.as_str(), + Self::ChildrenAreLeaves.as_str(), + Self::MaxSummary.as_str(), + Self::ParentIdx.as_str(), + ]; + + #[inline] + const fn as_str(&self) -> &'static str { + match self { + Self::Children => "children", + Self::ChildrenAreLeaves => "children_are_leaves", + Self::MaxSummary => "max_summary", + Self::ParentIdx => "parent_idx", + } + } + } + + impl Serialize for Inode + where + T::Summary: Serialize, + { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut inode = serializer + .serialize_struct(INODE_NAME, InodeField::AS_SLICE.len())?; + + inode.serialize_field( + InodeField::Children.as_str(), + &self.children[..self.num_children as usize], + )?; + + inode.serialize_field( + InodeField::ChildrenAreLeaves.as_str(), + &self.children_are_leaves, + )?; + + inode.serialize_field( + InodeField::MaxSummary.as_str(), + &self.max_summary, + )?; + + inode.serialize_field( + InodeField::ParentIdx.as_str(), + &self.parent_idx, + )?; + + inode.end() + } + } + + impl<'de, T: Summarize, const FANOUT: usize> de::Deserialize<'de> + for Inode + where + T::Summary: de::Deserialize<'de>, + { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + deserializer.deserialize_struct( + INODE_NAME, + InodeField::AS_SLICE, + InodeVisitor(PhantomData), + ) + } + } + + impl<'de, T: Summarize, const FANOUT: usize> de::Visitor<'de> + for InodeVisitor + where + T::Summary: de::Deserialize<'de>, + { + type Value = Inode; + + #[inline] + fn expecting( + &self, + formatter: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + write!(formatter, "a map representing an {INODE_NAME}") + } + + #[inline] + fn visit_map(self, mut map: A) -> Result + where + A: de::MapAccess<'de>, + { + let mut children: Option> = None; + let mut children_are_leaves: Option = None; + let mut max_summary: Option = None; + let mut parent_idx: Option = None; + + while let Some(inode_field) = map.next_key::()? { + match inode_field { + InodeField::Children => { + children = Some(map.next_value()?); + }, + InodeField::ChildrenAreLeaves => { + children_are_leaves = Some(map.next_value()?); + }, + InodeField::MaxSummary => { + max_summary = Some(map.next_value()?); + }, + InodeField::ParentIdx => { + parent_idx = Some(map.next_value()?); + }, + } + } + + let Children(children, num_children) = + children.ok_or_else(|| { + de::Error::missing_field(InodeField::Children.as_str()) + })?; + + let children_are_leaves = + children_are_leaves.ok_or_else(|| { + de::Error::missing_field( + InodeField::ChildrenAreLeaves.as_str(), + ) + })?; + + let max_summary = max_summary.ok_or_else(|| { + de::Error::missing_field(InodeField::MaxSummary.as_str()) + })?; + + let parent_idx = parent_idx.ok_or_else(|| { + de::Error::missing_field(InodeField::ParentIdx.as_str()) + })?; + + Ok(Inode { + children, + children_are_leaves, + max_summary, + num_children, + parent_idx, + }) + } + + #[inline] + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let Children(children, num_children) = seq + .next_element::>()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + + let children_are_leaves = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + + let max_summary = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?; + + let parent_idx = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(3, &self))?; + + if seq.next_element::()?.is_some() { + return Err(de::Error::invalid_length( + InodeField::AS_SLICE.len() + 1, + &self, + )); + } + + Ok(Inode { + children, + children_are_leaves, + max_summary, + num_children, + parent_idx, + }) + } + } + + struct Children([NodeIdx; N], Fanout); + + impl<'de, const N: usize> de::Deserialize<'de> for Children { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + struct ChildrenVisitor; + + impl<'de, const N: usize> de::Visitor<'de> for ChildrenVisitor { + type Value = Children; + + #[inline] + fn expecting( + &self, + formatter: &mut fmt::Formatter, + ) -> fmt::Result { + write!(formatter, "a sequence of up to {N} NodeIdxs") + } + + #[inline] + fn visit_seq( + self, + mut seq: A, + ) -> Result + where + A: de::SeqAccess<'de>, + { + let mut array = [NodeIdx::NONE; N]; + + for (array_idx, node_idx) in array.iter_mut().enumerate() { + match seq.next_element()? { + Some(elem) => *node_idx = elem, + None => { + return Ok(Children( + array, + array_idx as Fanout, + )); + }, + } + } + + if seq.next_element::()?.is_some() { + return Err(de::Error::invalid_length(N + 1, &self)); + } + + Ok(Children(array, N as Fanout)) + } + } + + deserializer.deserialize_seq(ChildrenVisitor::) + } + } +} + mod gtree_debug_as_tree { use super::*; use crate::debug; @@ -1912,3 +2204,49 @@ mod gtree_invariants { UnsortedChildren { parent_summary: Summary }, } } + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "serde")] + #[test] + fn gtree_bincode_roundtrip() { + let mut gtree = Gtree::::new(); + + gtree.append(1); + gtree.append(2); + gtree.append(3); + gtree.append(4); + gtree.append(5); + + let config = bincode::config::standard(); + + let serialized = + bincode::serde::encode_to_vec(>ree, config).unwrap(); + + let (deserialized, num_read) = bincode::serde::decode_from_slice::< + Gtree, + _, + >(&serialized, config) + .unwrap(); + + if num_read != serialized.len() { + panic!( + "bincode read {} bytes, expected {}", + num_read, + serialized.len() + ); + } + + assert_eq!(gtree, deserialized); + } + + impl Summarize for u32 { + type Summary = Self; + + fn summarize(&self) -> Self::Summary { + *self + } + } +} diff --git a/src/lamport_clock.rs b/src/lamport_clock.rs index bdbe527..204adca 100644 --- a/src/lamport_clock.rs +++ b/src/lamport_clock.rs @@ -1,11 +1,15 @@ use core::{cmp, fmt}; #[derive(Default, Copy, Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] pub(crate) struct LamportClock { next_timestamp: u64, } #[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] pub(crate) struct LamportTimestamp { inner: u64, } @@ -47,8 +51,10 @@ impl LamportTimestamp { /// * 1 bit to indicate whether the timestamp is [`NONE`](Self::NONE); const NUM_RESERVED_BITS: u8 = 1; + #[track_caller] #[inline] pub(crate) const fn into_u64(self) -> u64 { + debug_assert!(!self.is_none()); self.inner >> Self::NUM_RESERVED_BITS } @@ -58,7 +64,7 @@ impl LamportTimestamp { } #[inline] - const fn new(inner: u64) -> Self { + pub(crate) const fn new(inner: u64) -> Self { debug_assert!(inner <= Self::MAX); Self { inner: inner << Self::NUM_RESERVED_BITS } } diff --git a/src/lib.rs b/src/lib.rs index 907d9eb..cb7334d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,9 @@ mod node_id; mod op_log; mod op_timestamp; pub mod ops; +#[cfg(feature = "serde")] +#[cfg_attr(docsrs, doc(cfg(feature = "serde")))] +pub mod serde; mod tree; pub use tree::Tree; diff --git a/src/node_id.rs b/src/node_id.rs index f735655..2132ae5 100644 --- a/src/node_id.rs +++ b/src/node_id.rs @@ -6,6 +6,8 @@ use crate::{PeerId, forest}; /// A globally unique identifier for a node in a [`Tree`](crate::Tree). #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] pub struct GlobalNodeId { /// We use the timestamp of the operation that created the node as its /// global identifier. @@ -14,6 +16,8 @@ pub struct GlobalNodeId { /// TODO: docs. #[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] pub struct LocalNodeId { /// We use the [`NodeKey`](forest::NodeKey) of the node in the /// [`Forest`](forest::Forest) as its local identifier. @@ -61,12 +65,7 @@ impl fmt::Debug for GlobalNodeId { if self.is_none() { write!(f, "GlobalNodeId::NONE") } else { - write!( - f, - "GlobalNodeId({}.{})", - self.creation_timestamp.lamport_timestamp.into_u64(), - self.creation_timestamp.created_by - ) + write!(f, "GlobalNodeId({})", self.creation_timestamp.display()) } } } @@ -114,57 +113,4 @@ mod encode { forest::NodeKey::decode(buf).map(|forest_key| Self { forest_key }) } } - - #[cfg(feature = "serde")] - mod serde { - use ::serde::{de, ser}; - - use super::*; - - #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] - impl ser::Serialize for GlobalNodeId { - #[inline] - fn serialize(&self, serializer: S) -> Result - where - S: ser::Serializer, - { - encode::SerdeCompat(self).serialize(serializer) - } - } - - #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] - impl<'de> de::Deserialize<'de> for GlobalNodeId { - #[inline] - fn deserialize(deserializer: D) -> Result - where - D: de::Deserializer<'de>, - { - encode::SerdeCompat::deserialize(deserializer) - .map(|encode::SerdeCompat(this)| this) - } - } - - #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] - impl ser::Serialize for LocalNodeId { - #[inline] - fn serialize(&self, serializer: S) -> Result - where - S: ser::Serializer, - { - encode::SerdeCompat(self).serialize(serializer) - } - } - - #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] - impl<'de> de::Deserialize<'de> for LocalNodeId { - #[inline] - fn deserialize(deserializer: D) -> Result - where - D: de::Deserializer<'de>, - { - encode::SerdeCompat::deserialize(deserializer) - .map(|encode::SerdeCompat(this)| this) - } - } - } } diff --git a/src/op_log.rs b/src/op_log.rs index c194b98..bf8a286 100644 --- a/src/op_log.rs +++ b/src/op_log.rs @@ -19,6 +19,8 @@ const FANOUT: usize = 4; const FANOUT: usize = 32; #[derive(Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] pub(crate) struct OpLog { gtree: Gtree, FANOUT>, } @@ -44,10 +46,13 @@ pub(crate) struct Op { } #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +#[cfg_attr(feature = "serde", serde(transparent))] pub(crate) struct OpKey(gtree::Key); /// TODO: docs. #[derive(Clone)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] struct LogEntry { /// The parent of the node the [`op`](Self::op) acts on before the op was /// applied. Used when undoing ops to move the node back to its previous @@ -540,3 +545,169 @@ pub(crate) mod encode { } } } + +#[cfg(feature = "serde")] +pub(crate) mod serde_impls { + use core::marker::PhantomData; + + use serde::de; + use serde::ser::{self, SerializeStruct}; + + use super::*; + + const OP_NAME: &str = "Op"; + + #[derive(serde::Deserialize)] + #[serde(field_identifier, rename_all = "snake_case")] + enum OpField { + Metadata, + NewParent, + NodeId, + Timestamp, + } + + struct OpVisitor(PhantomData); + + impl OpField { + const AS_SLICE: &'static [&'static str] = &[ + Self::Metadata.as_str(), + Self::NewParent.as_str(), + Self::NodeId.as_str(), + Self::Timestamp.as_str(), + ]; + + #[inline] + const fn as_str(&self) -> &'static str { + match self { + Self::Metadata => "metadata", + Self::NewParent => "new_parent", + Self::NodeId => "node_id", + Self::Timestamp => "timestamp", + } + } + } + + impl ser::Serialize for Op { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: ser::Serializer, + { + let mut op = serializer + .serialize_struct(OP_NAME, OpField::AS_SLICE.len())?; + + op.serialize_field(OpField::Metadata.as_str(), &self.metadata())?; + op.serialize_field(OpField::NewParent.as_str(), &self.new_parent)?; + op.serialize_field(OpField::NodeId.as_str(), &self.node_id)?; + op.serialize_field(OpField::Timestamp.as_str(), &self.timestamp)?; + + op.end() + } + } + + impl<'de, M: de::Deserialize<'de>> de::Deserialize<'de> for Op { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + deserializer.deserialize_struct( + OP_NAME, + OpField::AS_SLICE, + OpVisitor(PhantomData), + ) + } + } + + impl<'de, M: de::Deserialize<'de>> de::Visitor<'de> for OpVisitor { + type Value = Op; + + #[inline] + fn expecting( + &self, + formatter: &mut fmt::Formatter<'_>, + ) -> fmt::Result { + write!(formatter, "a map representing an {OP_NAME}") + } + + #[inline] + fn visit_map(self, mut map: A) -> Result + where + A: de::MapAccess<'de>, + { + let mut metadata: Option> = None; + let mut new_parent: Option = None; + let mut node_id: Option = None; + let mut timestamp: Option = None; + + while let Some(op_field) = map.next_key::()? { + match op_field { + OpField::Metadata => { + metadata = Some( + map.next_value::>()? + .map(MaybeUninit::new) + .unwrap_or(MaybeUninit::uninit()), + ); + }, + OpField::NewParent => { + new_parent = Some(map.next_value()?); + }, + OpField::NodeId => { + node_id = Some(map.next_value()?); + }, + OpField::Timestamp => { + timestamp = map.next_value()?; + }, + } + } + + let metadata = metadata.ok_or_else(|| { + de::Error::missing_field(OpField::Metadata.as_str()) + })?; + + let new_parent = new_parent.ok_or_else(|| { + de::Error::missing_field(OpField::NewParent.as_str()) + })?; + + let node_id = node_id.ok_or_else(|| { + de::Error::missing_field(OpField::NodeId.as_str()) + })?; + + let timestamp = timestamp.ok_or_else(|| { + de::Error::missing_field(OpField::Timestamp.as_str()) + })?; + + Ok(Op { metadata, new_parent, node_id, timestamp }) + } + + #[inline] + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let metadata = seq + .next_element::>()? + .ok_or_else(|| de::Error::invalid_length(0, &self))? + .map(MaybeUninit::new) + .unwrap_or(MaybeUninit::uninit()); + + let new_parent = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + + let node_id = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?; + + let timestamp = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(3, &self))?; + + if seq.next_element::()?.is_some() { + return Err(de::Error::invalid_length(5, &self)); + } + + Ok(Op { metadata, new_parent, node_id, timestamp }) + } + } +} diff --git a/src/op_timestamp.rs b/src/op_timestamp.rs index b797f26..7a37667 100644 --- a/src/op_timestamp.rs +++ b/src/op_timestamp.rs @@ -1,4 +1,4 @@ -use core::{cmp, fmt}; +use core::{cmp, fmt, num, str}; use crate::PeerId; use crate::lamport_clock::LamportTimestamp; @@ -11,14 +11,24 @@ pub struct OpTimestamp { pub(crate) lamport_timestamp: LamportTimestamp, } +struct OpTimestampWrapper(OpTimestamp); + +enum OpTimestampFromStrError { + MissingDot, + InvalidLamportTimestamp(num::ParseIntError), + InvalidPeerId(num::ParseIntError), +} + +impl OpTimestamp { + #[inline] + pub(crate) fn display(&self) -> impl fmt::Display { + OpTimestampWrapper(*self) + } +} + impl fmt::Debug for OpTimestamp { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "OpTimestamp({}.{})", - self.lamport_timestamp.into_u64(), - self.created_by - ) + write!(f, "OpTimestamp({})", self.display()) } } @@ -39,6 +49,52 @@ impl Ord for OpTimestamp { } } +impl fmt::Display for OpTimestampWrapper { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let Self(OpTimestamp { created_by, lamport_timestamp }) = self; + write!(f, "{}.{created_by}", lamport_timestamp.into_u64()) + } +} + +impl str::FromStr for OpTimestampWrapper { + type Err = OpTimestampFromStrError; + + #[inline] + fn from_str(s: &str) -> Result { + let (lamport_timestamp_str, created_by_str) = + s.split_once('.').ok_or(OpTimestampFromStrError::MissingDot)?; + + let lamport_timestamp = lamport_timestamp_str + .parse::() + .map_err(OpTimestampFromStrError::InvalidLamportTimestamp)?; + + let created_by = created_by_str + .parse::() + .map_err(OpTimestampFromStrError::InvalidPeerId)?; + + Ok(OpTimestampWrapper(OpTimestamp { + created_by, + lamport_timestamp: LamportTimestamp::new(lamport_timestamp), + })) + } +} + +impl fmt::Display for OpTimestampFromStrError { + #[inline] + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + OpTimestampFromStrError::MissingDot => f.write_str("missing '.'"), + OpTimestampFromStrError::InvalidLamportTimestamp(err) => { + write!(f, "invalid lamport timestamp: {err}") + }, + OpTimestampFromStrError::InvalidPeerId(err) => { + write!(f, "invalid peer ID: {err}") + }, + } + } +} + #[cfg(feature = "encode")] mod encode { use super::*; @@ -63,3 +119,94 @@ mod encode { } } } + +#[cfg(feature = "serde")] +mod serde_impls { + use serde::de; + use serde::ser::{self, SerializeTuple}; + + use super::*; + use crate::node::GlobalNodeId; + + impl ser::Serialize for OpTimestamp { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: ser::Serializer, + { + if serializer.is_human_readable() { + if self.lamport_timestamp.is_none() { + serializer.serialize_none() + } else { + serializer.serialize_str(&self.display().to_string()) + } + } else { + let mut tuple = serializer.serialize_tuple(2)?; + tuple.serialize_element(&self.lamport_timestamp)?; + tuple.serialize_element(&self.created_by)?; + tuple.end() + } + } + } + + impl<'de> de::Deserialize<'de> for OpTimestamp { + #[inline] + fn deserialize(deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + if deserializer.is_human_readable() { + match >::deserialize(deserializer)? { + Some(str) => str + .parse::() + .map(|OpTimestampWrapper(op_timestamp)| op_timestamp) + .map_err(de::Error::custom), + None => Ok(GlobalNodeId::NONE.creation_timestamp), + } + } else { + struct OpTimestampVisitor; + + impl<'de> de::Visitor<'de> for OpTimestampVisitor { + type Value = OpTimestamp; + + #[inline] + fn expecting( + &self, + formatter: &mut fmt::Formatter, + ) -> fmt::Result { + formatter.write_str( + "a tuple of (lamport_timestamp, peer_id)", + ) + } + + #[inline] + fn visit_seq( + self, + mut seq: A, + ) -> Result + where + A: de::SeqAccess<'de>, + { + let lamport_timestamp = + seq.next_element()?.ok_or_else(|| { + de::Error::invalid_length(0, &self) + })?; + + let created_by = + seq.next_element()?.ok_or_else(|| { + de::Error::invalid_length(1, &self) + })?; + + if seq.next_element::()?.is_some() { + return Err(de::Error::invalid_length(3, &self)); + } + + Ok(OpTimestamp { created_by, lamport_timestamp }) + } + } + + deserializer.deserialize_tuple(2, OpTimestampVisitor) + } + } + } +} diff --git a/src/serde.rs b/src/serde.rs new file mode 100644 index 0000000..a9b9923 --- /dev/null +++ b/src/serde.rs @@ -0,0 +1,281 @@ +//! TODO: docs. + +use core::fmt; +use core::marker::PhantomData; + +use serde::de; +use serde::ser::{self, SerializeStruct}; + +use crate::forest::Forest; +use crate::lamport_clock::LamportClock; +use crate::op_log::OpLog; +use crate::{PeerId, ProtocolVersion, Tree}; + +/// TODO: docs. +pub struct DeserializeTree { + metadata: PhantomData, + peer_id: PeerId, +} + +/// TODO: docs. +pub struct SerializeTree<'tree, M> { + /// The tree to serialize. + tree: &'tree Tree, + + /// Whether to serialize the tree state. + with_tree_state: bool, +} + +#[derive(serde::Deserialize)] +#[serde(field_identifier, rename_all = "snake_case")] +enum TreeField { + ProtocolVersion, + LamportClock, + OpLog, + Forest, +} + +struct MismatchedProtocolVersionError { + serialized_on: ProtocolVersion, + deserializing_on: ProtocolVersion, +} + +impl DeserializeTree { + #[inline] + pub(crate) fn new(peer_id: PeerId) -> Self { + Self { metadata: PhantomData, peer_id } + } +} + +impl<'tree, M> SerializeTree<'tree, M> { + const NAME: &'static str = "SerializeTree"; + + /// TODO: docs. + #[inline] + pub fn with_tree_state(mut self, with_tree_state: bool) -> Self { + self.with_tree_state = with_tree_state; + self + } + + #[inline] + pub(crate) fn new(tree: &'tree Tree) -> Self { + Self { tree, with_tree_state: true } + } +} + +impl TreeField { + const AS_SLICE: &'static [&'static str] = &[ + Self::ProtocolVersion.as_str(), + Self::LamportClock.as_str(), + Self::OpLog.as_str(), + Self::Forest.as_str(), + ]; + + #[inline] + const fn as_str(&self) -> &'static str { + match self { + Self::ProtocolVersion => "protocol_version", + Self::LamportClock => "lamport_clock", + Self::OpLog => "op_log", + Self::Forest => "forest", + } + } +} + +impl fmt::Debug for SerializeTree<'_, M> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct(Self::NAME) + .field("tree", &self.tree) + .field("with_tree_state", &self.with_tree_state) + .finish() + } +} + +impl Copy for SerializeTree<'_, M> {} + +impl Clone for SerializeTree<'_, M> { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl ser::Serialize for SerializeTree<'_, M> { + #[inline] + fn serialize(&self, serializer: S) -> Result + where + S: ser::Serializer, + { + let mut tree = serializer + .serialize_struct(Self::NAME, TreeField::AS_SLICE.len())?; + + tree.serialize_field( + TreeField::ProtocolVersion.as_str(), + &crate::PROTOCOL_VERSION, + )?; + + tree.serialize_field( + TreeField::LamportClock.as_str(), + &self.tree.lamport_clock, + )?; + + tree.serialize_field(TreeField::OpLog.as_str(), &self.tree.op_log)?; + + tree.serialize_field( + TreeField::Forest.as_str(), + &self.with_tree_state.then_some(&self.tree.forest), + )?; + + tree.end() + } +} + +impl fmt::Debug for DeserializeTree { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("DeserializeTree") + .field("peer_id", &self.peer_id) + .finish() + } +} + +impl Copy for DeserializeTree {} + +impl Clone for DeserializeTree { + #[inline] + fn clone(&self) -> Self { + *self + } +} + +impl<'de, M: de::Deserialize<'de>> de::DeserializeSeed<'de> + for DeserializeTree +{ + type Value = Tree; + + #[inline] + fn deserialize(self, deserializer: D) -> Result + where + D: de::Deserializer<'de>, + { + deserializer.deserialize_struct( + SerializeTree::::NAME, + TreeField::AS_SLICE, + self, + ) + } +} + +impl<'de, M: de::Deserialize<'de>> de::Visitor<'de> for DeserializeTree { + type Value = Tree; + + #[inline] + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + write!(formatter, "a map representing a {}", SerializeTree::::NAME) + } + + #[inline] + fn visit_map(self, mut map: A) -> Result + where + A: de::MapAccess<'de>, + { + let mut protocol_version: Option = None; + let mut lamport_clock: Option = None; + let mut op_log: Option> = None; + let mut forest: Option = None; + + while let Some(tree_field) = map.next_key::()? { + match tree_field { + TreeField::ProtocolVersion => { + protocol_version = Some(map.next_value()?); + }, + TreeField::LamportClock => { + lamport_clock = Some(map.next_value()?); + }, + TreeField::OpLog => { + op_log = Some(map.next_value()?); + }, + TreeField::Forest => { + forest = map.next_value()?; + }, + } + } + + let protocol_version = protocol_version.ok_or_else(|| { + de::Error::missing_field(TreeField::ProtocolVersion.as_str()) + })?; + + if protocol_version != crate::PROTOCOL_VERSION { + return Err(de::Error::custom(MismatchedProtocolVersionError { + serialized_on: protocol_version, + deserializing_on: crate::PROTOCOL_VERSION, + })); + } + + let lamport_clock = lamport_clock.ok_or_else(|| { + de::Error::missing_field(TreeField::LamportClock.as_str()) + })?; + + let mut op_log = op_log.ok_or_else(|| { + de::Error::missing_field(TreeField::OpLog.as_str()) + })?; + + let forest = if let Some(forest) = forest { + forest + } else { + op_log.reconstruct_forest() + }; + + Ok(Tree { forest, op_log, lamport_clock, peer_id: self.peer_id }) + } + + #[inline] + fn visit_seq(self, mut seq: A) -> Result + where + A: de::SeqAccess<'de>, + { + let protocol_version: ProtocolVersion = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(0, &self))?; + + if protocol_version != crate::PROTOCOL_VERSION { + return Err(de::Error::custom(MismatchedProtocolVersionError { + serialized_on: protocol_version, + deserializing_on: crate::PROTOCOL_VERSION, + })); + } + + let lamport_clock: LamportClock = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(1, &self))?; + + let mut op_log: OpLog = seq + .next_element()? + .ok_or_else(|| de::Error::invalid_length(2, &self))?; + + let maybe_forest = seq + .next_element::>()? + .ok_or_else(|| de::Error::invalid_length(3, &self))?; + + let forest = if let Some(forest) = maybe_forest { + forest + } else { + op_log.reconstruct_forest() + }; + + if seq.next_element::()?.is_some() { + return Err(de::Error::invalid_length(5, &self)); + } + + Ok(Tree { forest, op_log, lamport_clock, peer_id: self.peer_id }) + } +} + +impl fmt::Display for MismatchedProtocolVersionError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "protocol mismatch: serialized on {}, deserializing on {}", + self.serialized_on, self.deserializing_on + ) + } +} diff --git a/src/tree.rs b/src/tree.rs index 3758e68..de3260a 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -10,6 +10,8 @@ use crate::node::{self, IsVisible, Node, NodeMut, Visible}; use crate::node_id::{GlobalNodeId, LocalNodeId}; use crate::op_log::OpLog; use crate::op_timestamp::OpTimestamp; +#[cfg(feature = "serde")] +use crate::serde::{DeserializeTree, SerializeTree}; use crate::{PeerId, ops}; /// TODO: docs. @@ -22,10 +24,10 @@ pub struct Tree { pub(crate) op_log: OpLog, /// The lamport clock used to.. - lamport_clock: LamportClock, + pub(crate) lamport_clock: LamportClock, /// The ID of the peer that owns this `Tree`. - peer_id: PeerId, + pub(crate) peer_id: PeerId, } impl Tree { @@ -67,6 +69,14 @@ impl Tree { Ok(Self { lamport_clock, op_log, peer_id, forest }) } + /// TODO: docs. + #[cfg(feature = "serde")] + #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] + #[inline] + pub fn deserialize(peer_id: PeerId) -> DeserializeTree { + DeserializeTree::new(peer_id) + } + /// TODO: docs. #[cfg(feature = "encode")] #[cfg_attr(docsrs, doc(cfg(feature = "encode")))] @@ -262,6 +272,14 @@ impl Tree { self.node_mut_inner(self.forest.root_id()) } + /// TODO: docs. + #[cfg(feature = "serde")] + #[cfg_attr(docsrs, doc(cfg(feature = "serde")))] + #[inline] + pub fn serialize(&self) -> SerializeTree<'_, M> { + SerializeTree::new(self) + } + #[doc(hidden)] #[track_caller] pub fn assert_invariants(&self) { @@ -340,7 +358,7 @@ impl Tree { } } -impl fmt::Debug for Tree { +impl fmt::Debug for Tree { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Tree") .field("peer_id", &self.peer_id) diff --git a/tests/encode.rs b/tests/encode.rs index 6ad02cb..ebf55f8 100644 --- a/tests/encode.rs +++ b/tests/encode.rs @@ -5,42 +5,42 @@ mod common; use common::{OpKind, Tree}; #[test] -fn fuzz_4_ops() { +fn encode_fuzz_4_ops_state() { fuzz(FuzzOpts { num_ops: 4, ..Default::default() }) } #[test] -fn fuzz_4_no_state() { +fn encode_fuzz_4_ops_no_state() { fuzz(FuzzOpts { encode_state: false, num_ops: 4, ..Default::default() }) } #[test] -fn fuzz_32_ops() { +fn encode_fuzz_32_ops_state() { fuzz(FuzzOpts { num_ops: 32, ..Default::default() }) } #[test] -fn fuzz_32_ops_no_state() { +fn encode_fuzz_32_ops_no_state() { fuzz(FuzzOpts { encode_state: false, num_ops: 32, ..Default::default() }) } #[test] -fn fuzz_256_ops() { +fn encode_fuzz_256_ops_state() { fuzz(FuzzOpts { num_ops: 256, ..Default::default() }) } #[test] -fn fuzz_256_ops_no_state() { +fn encode_fuzz_256_ops_no_state() { fuzz(FuzzOpts { encode_state: false, num_ops: 256, ..Default::default() }) } #[test] -fn fuzz_2048_ops() { +fn encode_fuzz_2048_ops_state() { fuzz(FuzzOpts { num_ops: 2048, ..Default::default() }) } #[test] -fn fuzz_2048_ops_no_state() { +fn encode_fuzz_2048_ops_no_state() { fuzz(FuzzOpts { encode_state: false, num_ops: 2048, ..Default::default() }) } diff --git a/tests/serde.rs b/tests/serde.rs new file mode 100644 index 0000000..b098b37 --- /dev/null +++ b/tests/serde.rs @@ -0,0 +1,252 @@ +#![allow(missing_docs)] + +mod common; + +use std::str; + +use common::{OpKind, Tree}; +use serde::Serialize; +use serde::de::DeserializeSeed; + +#[test] +fn serde_json_fuzz_4_ops_state() { + fuzz(FuzzOpts { num_ops: 4, ..Default::default() }) +} + +#[test] +fn serde_json_fuzz_4_ops_no_state() { + fuzz(FuzzOpts { + num_ops: 4, + serialize_tree_state: false, + ..Default::default() + }) +} + +#[test] +fn serde_json_fuzz_32_ops_state() { + fuzz(FuzzOpts { num_ops: 32, ..Default::default() }) +} + +#[test] +fn serde_json_fuzz_32_ops_no_state() { + fuzz(FuzzOpts { + num_ops: 32, + serialize_tree_state: false, + ..Default::default() + }) +} + +#[test] +fn serde_json_fuzz_256_ops_state() { + fuzz(FuzzOpts { num_ops: 256, ..Default::default() }) +} + +#[test] +fn serde_json_fuzz_256_ops_no_state() { + fuzz(FuzzOpts { + num_ops: 256, + serialize_tree_state: false, + ..Default::default() + }) +} + +#[test] +fn serde_json_fuzz_2048_ops_state() { + fuzz(FuzzOpts { num_ops: 2048, ..Default::default() }) +} + +#[test] +fn serde_json_fuzz_2048_ops_no_state() { + fuzz(FuzzOpts { + num_ops: 2048, + serialize_tree_state: false, + ..Default::default() + }) +} + +#[test] +fn serde_bincode_fuzz_4_ops_state() { + fuzz(FuzzOpts { + data_format: DataFormat::Bincode, + num_ops: 4, + ..Default::default() + }) +} + +#[test] +fn serde_bincode_fuzz_4_ops_no_state() { + fuzz(FuzzOpts { + data_format: DataFormat::Bincode, + num_ops: 4, + serialize_tree_state: false, + ..Default::default() + }) +} + +#[test] +fn serde_bincode_fuzz_32_ops_state() { + fuzz(FuzzOpts { + data_format: DataFormat::Bincode, + num_ops: 32, + ..Default::default() + }) +} + +#[test] +fn serde_bincode_fuzz_32_ops_no_state() { + fuzz(FuzzOpts { + data_format: DataFormat::Bincode, + num_ops: 32, + serialize_tree_state: false, + ..Default::default() + }) +} + +#[test] +fn serde_bincode_fuzz_256_ops_state() { + fuzz(FuzzOpts { num_ops: 256, ..Default::default() }) +} + +#[test] +fn serde_bincode_fuzz_256_ops_no_state() { + fuzz(FuzzOpts { + data_format: DataFormat::Bincode, + num_ops: 256, + serialize_tree_state: false, + ..Default::default() + }) +} + +#[test] +fn serde_bincode_fuzz_2048_ops_state() { + fuzz(FuzzOpts { + data_format: DataFormat::Bincode, + num_ops: 2048, + ..Default::default() + }) +} + +#[test] +fn serde_bincode_fuzz_2048_ops_no_state() { + fuzz(FuzzOpts { + data_format: DataFormat::Bincode, + num_ops: 2048, + serialize_tree_state: false, + ..Default::default() + }) +} + +#[derive(Debug)] +#[non_exhaustive] +struct FuzzOpts { + data_format: DataFormat, + num_ops: u32, + serialize_tree_state: bool, +} + +#[derive(Debug)] +enum DataFormat { + Bincode, + Json, +} + +impl DataFormat { + fn bincode_config(&self) -> impl bincode::config::Config { + bincode::config::standard() + } + + #[track_caller] + fn deserialize_seed<'de, T: DeserializeSeed<'de>>( + &self, + seed: T, + serialized: &'de [u8], + ) -> T::Value { + match self { + DataFormat::Bincode => { + let (deserialized, num_read) = + bincode::serde::seed_decode_from_slice( + seed, + serialized, + self.bincode_config(), + ) + .expect("failed to deserialize with bincode"); + + if num_read != serialized.len() { + panic!( + "deserialized {} bytes, expected {}", + num_read, + serialized.len() + ); + } + + deserialized + }, + + DataFormat::Json => { + let serialized_str = str::from_utf8(serialized).expect( + "failed to convert bytes to string for JSON \ + deserialization", + ); + + seed.deserialize(&mut serde_json::Deserializer::from_str( + serialized_str, + )) + .expect("failed to deserialize with JSON") + }, + } + } + + #[track_caller] + fn serialize(&self, value: &impl Serialize) -> Vec { + match self { + DataFormat::Bincode => { + bincode::serde::encode_to_vec(value, self.bincode_config()) + .expect("failed to serialize with bincode") + }, + + DataFormat::Json => serde_json::to_string_pretty(value) + .expect("failed to serialize with JSON") + .into_bytes(), + } + } +} + +impl Default for FuzzOpts { + fn default() -> Self { + Self { + data_format: DataFormat::Json, + num_ops: 1, + serialize_tree_state: true, + } + } +} + +#[track_caller] +fn fuzz(opts: FuzzOpts) { + assert!(opts.num_ops > 0); + + let mut tree = Tree::new(42u32, 1); + + let weights = + [(OpKind::Creation, 50), (OpKind::Deletion, 10), (OpKind::Move, 40)]; + + common::with_rng(|rng| { + // Perform a bunch of random ops. + for _ in 0..opts.num_ops { + tree.perform_random_op(weights, |_| 42, rng); + } + + // Serialize the tree. + let serialized = opts.data_format.serialize( + &tree.serialize().with_tree_state(opts.serialize_tree_state), + ); + + // Deserialize the tree. + let deserialized = opts + .data_format + .deserialize_seed(pando::Tree::deserialize(2), &serialized); + + // Assert that the deserialized tree is equal to the original tree. + assert_eq!(Tree::from(deserialized), tree); + }); +}