diff --git a/Cargo.lock b/Cargo.lock index 1676fa900f..1f43aaf6db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2447,6 +2447,7 @@ dependencies = [ "log", "octez-riscv-data", "octez-riscv-test-utils", + "perfect-derive", "proptest", "rand 0.9.2", "rocksdb", @@ -2455,6 +2456,7 @@ dependencies = [ "tempfile", "thiserror 2.0.18", "tokio", + "trait-set", ] [[package]] diff --git a/durable-storage/Cargo.toml b/durable-storage/Cargo.toml index 1a18e7ec4e..6a5c27d24c 100644 --- a/durable-storage/Cargo.toml +++ b/durable-storage/Cargo.toml @@ -20,6 +20,9 @@ rocksdb.workspace = true tempfile.workspace = true thiserror.workspace = true tokio.workspace = true +trait-set.workspace = true +perfect-derive.workspace = true + [dev-dependencies] proptest.workspace = true diff --git a/durable-storage/benches/avl_tree.rs b/durable-storage/benches/avl_tree.rs index f723dee18d..dbb6c026ec 100644 --- a/durable-storage/benches/avl_tree.rs +++ b/durable-storage/benches/avl_tree.rs @@ -10,6 +10,7 @@ use bytes::Bytes; use criterion::Criterion; use criterion::criterion_group; use criterion::criterion_main; +use octez_riscv_durable_storage::avl::resolver::ArcNodeId; use octez_riscv_durable_storage::avl::resolver::ArcResolver; use octez_riscv_durable_storage::avl::tree::Tree; use octez_riscv_durable_storage::key::Key; @@ -51,7 +52,7 @@ fn bench_avl_tree_operations(c: &mut Criterion) { let mut resolver = ArcResolver; // Setting up the tree - let mut tree = Tree::default(); + let mut tree = Tree::::default(); for key in &keys[..keys.len() / 2] { let random_data = generate_random_bytes_in_range(&mut rng, 1..20); let _ = tree diff --git a/durable-storage/src/avl/node.rs b/durable-storage/src/avl/node.rs index 577c11bc7f..91b29a498a 100644 --- a/durable-storage/src/avl/node.rs +++ b/durable-storage/src/avl/node.rs @@ -5,17 +5,18 @@ //! Interface for a Merklisable node of an AVL tree use std::cmp::Ordering; +#[cfg(test)] use std::fmt::Debug; -use std::sync::Arc; use std::sync::OnceLock; use bincode::Encode; use octez_riscv_data::components::bytes::Bytes; use octez_riscv_data::hash::Hash; use octez_riscv_data::mode::Normal; +use perfect_derive::perfect_derive; use super::tree::Tree; -use crate::avl::resolver::Resolver; +use crate::avl::resolver::NodeResolver; use crate::errors::OperationalError; use crate::key::Key; @@ -23,12 +24,12 @@ use crate::key::Key; pub type Value = Bytes; /// A node that supports rebalancing and Merklisation. -#[derive(Clone, Default, Debug)] -pub struct Node { +#[perfect_derive(Clone, Default, Debug)] +pub struct Node { key: Key, data: Value, - left: Tree, - right: Tree, + left: Tree, + right: Tree, /// A cache for the hash of this node. This uses `OnceLock` so that updating the cache is a /// non-mutating operation. @@ -52,7 +53,7 @@ struct NodeHashRepresentation<'a, Value> { balance_factor: i64, } -impl Node { +impl Node { /// Create a new leaf [`Node`] from the given key and data. pub(crate) fn new(key: Key, data: impl Into) -> Self { Node { @@ -67,7 +68,7 @@ impl Node { /// [`NodeHashRepresentation`], potentially re-hashing uncached [`Node`]s. pub(crate) fn to_encode( &self, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result { // Recursively hashes any left child and its children let left = self @@ -111,14 +112,14 @@ impl Node { #[inline] /// A mutable reference to the left branch. - pub(super) fn left_mut(&mut self) -> &mut Tree { + pub(super) fn left_mut(&mut self) -> &mut Tree { self.invalidate_hash(); &mut self.left } #[inline] /// An immutable reference to the left branch. - pub(super) fn left_ref(&self) -> &Tree { + pub(super) fn left_ref(&self) -> &Tree { &self.left } @@ -128,8 +129,8 @@ impl Node { /// The subtree of the [`Node`] must already have balance factor in the range of -2..=2, else /// it is an invalid AVL tree. pub(super) fn rebalance( - node: &mut Arc, - resolver: &mut impl Resolver, Node>, + node: &mut Id, + resolver: &mut impl NodeResolver, ) -> Result<(), OperationalError> { let resolved_node = resolver.resolve(node)?; let balance_factor = resolved_node.balance_factor(); @@ -172,15 +173,15 @@ impl Node { /// - The [`Node`] at the root of the new subtree. /// - `true` if the [`Tree`] has shrunk in size. pub(super) fn replace_with_successor( - node: &mut Arc, - resolver: &mut impl Resolver, Node>, - ) -> Result<(Arc, bool), OperationalError> { - let node_mut = resolver.resolve_mut(node)?; + current: &mut Id, + resolver: &mut impl NodeResolver, + ) -> Result<(Id, bool), OperationalError> { + let current = resolver.resolve_mut(current)?; // If the right child has a left child, the successor is the min of the left child's subtree. let (mut successor, shrank) = if resolver .resolve( - node_mut + current .right_ref() .root() .expect("A node with a successor must have a right child"), @@ -189,8 +190,8 @@ impl Node { .root() .is_some() { - let right = node_mut.right_mut(); - let (mut min, _, shrank) = Tree::take_min(right, resolver)?; + let right = current.right_mut(); + let (mut min, _, shrank) = right.take_min(resolver)?; ( min.take() .expect("A node with a successor must have a right child"), @@ -198,39 +199,40 @@ impl Node { ) // If there is no left child of the right child, the successor is the right child. } else { - let mut successor = node_mut + let mut successor = current .right_mut() .take() .expect("A node with a successor must have a right child"); let successor_mut = resolver.resolve_mut(&mut successor)?; // Bump up the (optional) right child of the right child, causing the subtree to shrink. - node_mut.right = successor_mut.right.take().into(); + current.right = successor_mut.right.take().into(); (successor, true) }; let successor_mut = resolver.resolve_mut(&mut successor)?; - successor_mut.balance_factor = node_mut.balance_factor() - if shrank { 1 } else { 0 }; - successor_mut.left = std::mem::take(&mut node_mut.left); - successor_mut.right = std::mem::take(&mut node_mut.right); + successor_mut.balance_factor = current.balance_factor() - if shrank { 1 } else { 0 }; + successor_mut.left = std::mem::take(&mut current.left); + successor_mut.right = std::mem::take(&mut current.right); Self::rebalance(&mut successor, resolver)?; - let shrank = node_mut.balance_factor().abs() == 1 && successor.balance_factor == 0; + let successor_node = resolver.resolve(&successor)?; + let shrank = current.balance_factor().abs() == 1 && successor_node.balance_factor == 0; Ok((successor, shrank)) } #[inline] /// A mutable reference to the right branch. - pub(super) fn right_mut(&mut self) -> &mut Tree { + pub(super) fn right_mut(&mut self) -> &mut Tree { self.invalidate_hash(); &mut self.right } #[inline] /// An immutable reference to the right branch. - pub(super) fn right_ref(&self) -> &Tree { + pub(super) fn right_ref(&self) -> &Tree { &self.right } @@ -242,9 +244,9 @@ impl Node { /// - The minimum [`Tree`]'s right child, if it hasn't been moved to its new position. /// - True if this [`Node`]'s subtree has shrunk in size. pub(super) fn take_min( - node: &mut Arc, - resolver: &mut impl Resolver, Node>, - ) -> Result<(Tree, Tree, bool), OperationalError> { + node: &mut Id, + resolver: &mut impl NodeResolver, + ) -> Result<(Tree, Tree, bool), OperationalError> { let node_mut = resolver.resolve_mut(node)?; let old_node_bf = node_mut.balance_factor(); @@ -277,7 +279,7 @@ impl Node { key: &Key, offset: usize, data: impl FnOnce(&mut Value), - resolver: &mut impl Resolver, Node>, + resolver: &mut impl NodeResolver, ) -> Result { // SAFETY: The default recursion limit in Rust is 128 // see: @@ -341,8 +343,8 @@ impl Node { /// Assumes this [`Node`]'s balance factor is 2 and the right [`Node`]'s balance factor is +1 /// or 0. fn rotate_left( - node: &mut Arc, - resolver: &mut impl Resolver, Node>, + node: &mut Id, + resolver: &mut impl NodeResolver, ) -> Result<(), OperationalError> { let node_mut = resolver.resolve_mut(node)?; let mut right = node_mut @@ -405,8 +407,8 @@ impl Node { /// /// Assumes this [`Node`]'s balance factor is -2 and the left [Node]'s balance factor is +1. fn rotate_left_right( - node: &mut Arc, - resolver: &mut impl Resolver, Node>, + node: &mut Id, + resolver: &mut impl NodeResolver, ) -> Result<(), OperationalError> { let node_mut = resolver.resolve_mut(node)?; @@ -421,14 +423,14 @@ impl Node { .take() .expect("Left's right child must exist for the left rotation of the left node"); + let left_right_mut = resolver.resolve_mut(&mut left_right)?; + // From the `rotate_left` derivation, the first rotation does: // new_A_bf_1 = old_A_bf - 1 + std::cmp::min(-A.right.balance_factor, 0) // As this function assumes old_A_bf is +1: // new_A_bf_1 = std::cmp::min(-A.right.balance_factor, 0) // The second rotation doesn't mutate A's subtree, so the final balance factor is: - left_mut.balance_factor = std::cmp::min(-left_right.balance_factor, 0); - - let left_right_mut = resolver.resolve_mut(&mut left_right)?; + left_mut.balance_factor = std::cmp::min(-left_right_mut.balance_factor, 0); // B's right child is between B and B, it's moved to node's left node_mut.left = left_right_mut.right.take().into(); @@ -468,8 +470,8 @@ impl Node { /// Assumes this [`Node`]'s balance factor is -2 and the left [`Node`]'s balance factor is -1 /// or 0. fn rotate_right( - node: &mut Arc, - resolver: &mut impl Resolver, Node>, + node: &mut Id, + resolver: &mut impl NodeResolver, ) -> Result<(), OperationalError> { let node_mut = resolver.resolve_mut(node).expect("Node must exist."); let mut left = node_mut @@ -532,8 +534,8 @@ impl Node { /// /// Assumes this [`Node`]'s balance factor is +2 and the left [Node]'s balance factor is -1. fn rotate_right_left( - node: &mut Arc, - resolver: &mut impl Resolver, Node>, + node: &mut Id, + resolver: &mut impl NodeResolver, ) -> Result<(), OperationalError> { let node_mut = resolver.resolve_mut(node)?; let mut right = node_mut @@ -547,14 +549,14 @@ impl Node { .take() .expect("Right's left child must exist for the right rotation of the right node"); + let right_left_mut = resolver.resolve_mut(&mut right_left)?; + // From the `rotate_right` derivation, the first rotation does: // new_A_bf_1 = old_A_bf + 1 + std::cmp::max(0, -A.left.balance_factor) // As this function assumes old_A_bf is -1: // new_A_bf_1 = std::cmp::max(0, -A.left.balance_factor) // The second rotation doesn't mutate A's subtree, so the final balance factor is: - right_mut.balance_factor = std::cmp::max(0, -right_left.balance_factor); - - let right_left_mut = resolver.resolve_mut(&mut right_left)?; + right_mut.balance_factor = std::cmp::max(0, -right_left_mut.balance_factor); // B's left child is between node and B, it's moved to node's right node_mut.right = right_left_mut.left.take().into(); @@ -583,9 +585,9 @@ impl Node { /// /// If the hash has been cached, the memo is returned. Otherwise, the hash is calculated and /// cached. -pub(crate) fn hash<'a>( - node: &'a Arc, - resolver: &impl Resolver, Node>, +pub(crate) fn hash<'a, Id: Clone>( + node: &'a Id, + resolver: &impl NodeResolver, ) -> Result<&'a Hash, OperationalError> { let resolved = resolver.resolve(node)?; @@ -601,7 +603,7 @@ pub(crate) fn hash<'a>( } #[cfg(test)] -impl Node { +impl Node { #[inline] /// The data stored in the [`Node`]. pub(crate) fn data(&self) -> &Value { @@ -610,9 +612,9 @@ impl Node { /// The data stored in a [`Node`] within the subtree of this [`Node`] with a given [`Key`] . pub(super) fn get<'a>( - mut node: &'a Arc, + mut node: &'a Id, key: &Key, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result, OperationalError> { loop { let resolved_node = resolver.resolve(node)?; @@ -637,7 +639,7 @@ impl Node { /// Returns true if the balance factors stored in the [`Node`]'s subtree are correct. pub(super) fn has_correct_balance_factors( &self, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result { let left_height = self.left_ref().height(resolver)?; let right_height = self.right_ref().height(resolver)?; @@ -655,10 +657,7 @@ impl Node { } /// Returns the height of this [`Node`]'s subtree. - pub(super) fn height( - &self, - resolver: &impl Resolver, Node>, - ) -> Result { + pub(super) fn height(&self, resolver: &impl NodeResolver) -> Result { let left_height = self.left_ref().height(resolver)?; let right_height = self.right_ref().height(resolver)?; Ok(1 + std::cmp::max(left_height, right_height)) @@ -667,7 +666,7 @@ impl Node { /// Returns true if this [`Node`]'s subtree is balanced. pub(super) fn is_balanced( &self, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result { let balance_factor = self.balance_factor(); if balance_factor.abs() > 1 { @@ -684,7 +683,7 @@ impl Node { &self, min: &Key, max: &Key, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result { if self.key() < min || self.key() > max { return Ok(false); diff --git a/durable-storage/src/avl/resolver.rs b/durable-storage/src/avl/resolver.rs index bba4bdadee..931c03c3a2 100644 --- a/durable-storage/src/avl/resolver.rs +++ b/durable-storage/src/avl/resolver.rs @@ -7,12 +7,21 @@ //! [`Tree`]: crate::avl::tree::Tree //! [`Node`]: crate::avl::node::Node +use std::borrow::Borrow; +use std::borrow::BorrowMut; use std::sync::Arc; +use perfect_derive::perfect_derive; +use trait_set::trait_set; + +use crate::avl::node::Node; use crate::errors::OperationalError; /// Trait for resolving identifiers to values. pub trait Resolver { + /// Create a new `Id` from a `Value`. + fn create_id(&self, value: Value) -> Id; + /// Resolve an identifier to a value. fn resolve<'a>(&self, id: &'a Id) -> Result<&'a Value, OperationalError>; @@ -24,12 +33,44 @@ pub trait Resolver { #[derive(Clone, Debug)] pub struct ArcResolver; -impl Resolver, T> for ArcResolver { - fn resolve<'a>(&self, id: &'a Arc) -> Result<&'a T, OperationalError> { - Ok(id.as_ref()) +impl> + From>> Resolver for ArcResolver { + fn create_id(&self, value: T) -> Id { + Id::from(Arc::new(value)) + } + + fn resolve<'a>(&self, id: &'a Id) -> Result<&'a T, OperationalError> { + Ok(id.borrow().as_ref()) + } + + fn resolve_mut<'a>(&mut self, id: &'a mut Id) -> Result<&'a mut T, OperationalError> { + Ok(Arc::make_mut(id.borrow_mut())) + } +} + +/// An identifier for a [`Node`] that is stored as an [`Arc`]. +#[derive(Debug)] +#[perfect_derive(Clone, Default)] +pub struct ArcNodeId(Arc>); + +impl Borrow>> for ArcNodeId { + fn borrow(&self) -> &Arc> { + &self.0 + } +} + +impl BorrowMut>> for ArcNodeId { + fn borrow_mut(&mut self) -> &mut Arc> { + &mut self.0 } +} - fn resolve_mut<'a>(&mut self, id: &'a mut Arc) -> Result<&'a mut T, OperationalError> { - Ok(Arc::make_mut(id)) +impl From>> for ArcNodeId { + fn from(value: Arc>) -> Self { + ArcNodeId(value) } } + +trait_set! { + /// A resolver for [`Node`] identifiers. + pub trait NodeResolver = Resolver>; +} diff --git a/durable-storage/src/avl/tree.rs b/durable-storage/src/avl/tree.rs index 567e75106e..75f75ca67c 100644 --- a/durable-storage/src/avl/tree.rs +++ b/durable-storage/src/avl/tree.rs @@ -5,31 +5,32 @@ //! Interface for an optional root [`Node`] of a Merklisable AVL tree use std::cmp::Ordering; +#[cfg(test)] use std::fmt::Debug; -use std::sync::Arc; use octez_riscv_data::hash::Hash; +use perfect_derive::perfect_derive; use super::node::Node; use super::node::Value; -use crate::avl::resolver::Resolver; +use crate::avl::resolver::NodeResolver; use crate::errors::OperationalError; #[cfg(test)] use crate::key::KEY_MAX_SIZE; use crate::key::Key; /// A key-value store tree with left and right nodes that supports traversal and value retrieval. -#[derive(Clone, Default, Debug)] -pub struct Tree(Option>); +#[perfect_derive(Clone, Default, Debug)] +pub struct Tree(Option); -impl Tree { +impl Tree { /// Delete the [`Node`] in the [`Tree`] with a given key. /// /// Returns true if the [`Tree`] has shrunk in size. pub fn delete( &mut self, key: &Key, - resolver: &mut impl Resolver, Node>, + resolver: &mut impl NodeResolver, ) -> Result { let old_balance_factor = self.balance_factor(resolver)?; let Some(node) = self.root_mut() else { @@ -86,7 +87,7 @@ impl Tree { &mut self, key: &Key, data: &[u8], - resolver: &mut impl Resolver, Node>, + resolver: &mut impl NodeResolver, ) -> Result { self.upsert(key, 0, |old_data| old_data.set(data), resolver) } @@ -95,28 +96,32 @@ impl Tree { /// /// If the [`struct@Hash`] has been cached, the memo is returned. Otherwise, the /// [`struct@Hash`] is calculated and cached. - pub(crate) fn hash( - &self, - resolver: &impl Resolver, Node>, - ) -> Result { + pub(crate) fn hash(&self, resolver: &impl NodeResolver) -> Result { let encodable = self .0 - .as_deref() - .map(|node| node.to_encode(resolver)) + .as_ref() + .map(|node| { + let node = resolver.resolve(node).expect("Node must exist."); + node.to_encode(resolver) + }) .transpose()?; Ok(Hash::hash_encodable(encodable).expect("Should be hashable")) } /// Creates an in-order iterator for the [`Node`]s in the [`Tree`] - pub(crate) fn iter(&self) -> TreeIterator { + pub(crate) fn iter<'a, R: NodeResolver>( + &'a self, + resolver: &'a R, + ) -> TreeIterator<'a, Id, R> { TreeIterator { stack: vec![], current: self, + resolver, } } /// Take the root [`Node`] out of this tree, leaving the [`Tree`] empty. - pub(crate) const fn take(&mut self) -> Option> { + pub(crate) const fn take(&mut self) -> Option { self.0.take() } @@ -124,7 +129,7 @@ impl Tree { /// The difference in heights between any child branches in the [`Tree`]. pub(super) fn balance_factor( &self, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result { let Some(node) = self.root() else { return Ok(0); @@ -137,13 +142,13 @@ impl Tree { #[inline] /// A reference to the root [`Node`]. - pub(super) fn root(&self) -> Option<&Arc> { + pub(super) fn root(&self) -> Option<&Id> { self.0.as_ref() } #[inline] /// A mutable reference to the root [`Node`]. - pub(super) fn root_mut(&mut self) -> Option<&mut Arc> { + pub(super) fn root_mut(&mut self) -> Option<&mut Id> { self.0.as_mut() } @@ -156,8 +161,8 @@ impl Tree { /// - True if the [`Tree`] has shrunk in size. pub(super) fn take_min( &mut self, - resolver: &mut impl Resolver, Node>, - ) -> Result<(Tree, Tree, bool), OperationalError> { + resolver: &mut impl NodeResolver, + ) -> Result<(Tree, Tree, bool), OperationalError> { let Some(node_arc) = self.root_mut() else { return Ok((None.into(), None.into(), false)); }; @@ -184,7 +189,7 @@ impl Tree { key: &Key, offset: usize, data: impl FnOnce(&mut Value), - resolver: &mut impl Resolver, Node>, + resolver: &mut impl NodeResolver, ) -> Result { let node = self.root_mut(); let Some(node) = node else { @@ -199,7 +204,7 @@ impl Tree { data(&mut new_data); // The key does not exist and a new `Node` shall be created. - self.0 = Some(Arc::new(Node::new(key.clone(), new_data))); + self.0 = Some(resolver.create_id(Node::new(key.clone(), new_data))); return Ok(true); }; let grew = resolver @@ -222,7 +227,7 @@ impl Tree { key: &Key, offset: usize, data: &[u8], - resolver: &mut impl Resolver, Node>, + resolver: &mut impl NodeResolver, ) -> Result { self.upsert( key, @@ -253,10 +258,7 @@ impl Tree { /// /// The [`Tree`] must already have balance factor in the range of -2..=2, else it is an invalid /// AVL tree. - fn rebalance( - &mut self, - resolver: &mut impl Resolver, Node>, - ) -> Result<(), OperationalError> { + fn rebalance(&mut self, resolver: &mut impl NodeResolver) -> Result<(), OperationalError> { match self.root_mut() { Some(node) => Node::rebalance(node, resolver), None => Ok(()), @@ -264,41 +266,45 @@ impl Tree { } } -impl From>> for Tree { - fn from(node: Option>) -> Self { +impl From> for Tree { + fn from(node: Option) -> Self { Tree(node) } } /// Used for iterating through the nodes of the [`Tree`] tree in order. -pub(crate) struct TreeIterator<'a> { - stack: Vec<&'a Arc>, - current: &'a Tree, +pub(crate) struct TreeIterator<'a, Id, R> { + stack: Vec<&'a Id>, + current: &'a Tree, + resolver: &'a R, } -impl<'a> Iterator for TreeIterator<'a> { - type Item = &'a Arc; +impl<'a, Id: Clone, R: NodeResolver> Iterator for TreeIterator<'a, Id, R> { + type Item = &'a Id; fn next(&mut self) -> Option { while let Some(node) = self.current.root() { self.stack.push(node); - self.current = node.left_ref(); + + let resolved_node = self.resolver.resolve(node).ok()?; + self.current = resolved_node.left_ref(); } let ret = self.stack.pop()?; - self.current = ret.right_ref(); + let resolved_node = self.resolver.resolve(ret).ok()?; + self.current = resolved_node.right_ref(); Some(ret) } } #[cfg(test)] -impl Tree { +impl Tree { #[inline] /// The data stored in a [`Node`] in the [`Tree`] with a given [`Key`]. pub fn get( &self, key: &Key, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result, OperationalError> { let Some(node) = self.root() else { return Ok(None); @@ -307,10 +313,7 @@ impl Tree { } /// Asserts that the [`Tree`] is a valid AVL tree - pub(crate) fn check( - &self, - resolver: &impl Resolver, Node>, - ) -> Result<(), OperationalError> { + pub(crate) fn check(&self, resolver: &impl NodeResolver) -> Result<(), OperationalError> { let inorder = self.is_inorder(resolver)?; let is_balanced = self.is_balanced(resolver)?; let has_correct_balance_factors = self.has_correct_balance_factors(resolver)?; @@ -329,7 +332,7 @@ impl Tree { /// Returns true if the [`Tree`] is in-order. pub(crate) fn is_inorder( &self, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result { self.is_inorder_inner( &Key::new(&[u8::MIN]).expect("Size less than KEY_MAX_SIZE"), @@ -341,7 +344,7 @@ impl Tree { /// Returns true if the balance factors stored in any [`Node`]'s subtree are correct. pub(super) fn has_correct_balance_factors( &self, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result { match self.root() { None => Ok(true), @@ -352,10 +355,7 @@ impl Tree { } /// Returns the height of the [`Tree`]. - pub(super) fn height( - &self, - resolver: &impl Resolver, Node>, - ) -> Result { + pub(super) fn height(&self, resolver: &impl NodeResolver) -> Result { match self.root() { None => Ok(0), Some(node) => resolver.resolve(node).map(|res| res.height(resolver))?, @@ -365,7 +365,7 @@ impl Tree { /// Returns true if the [`Tree`] is balanced. pub(super) fn is_balanced( &self, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result { match self.root() { None => Ok(true), @@ -380,7 +380,7 @@ impl Tree { &self, min: &Key, max: &Key, - resolver: &impl Resolver, Node>, + resolver: &impl NodeResolver, ) -> Result { match self.root() { None => Ok(true), @@ -393,15 +393,20 @@ impl Tree { #[cfg(test)] mod tests { + use std::borrow::Borrow; + use std::borrow::BorrowMut; use std::collections::BTreeMap; use std::io::prelude::*; + use std::sync::Arc; use bytes::Bytes; use goldenfile::Mint; use proptest::prelude::*; use super::*; + use crate::avl::resolver::ArcNodeId; use crate::avl::resolver::ArcResolver; + use crate::avl::resolver::Resolver; use crate::key::KEY_MAX_SIZE; use crate::key::Key; @@ -445,11 +450,13 @@ mod tests { }) } - fn compare_tree_to_reference(tree: &Tree, reference: &BTreeMap) { - let tree_iter = tree.iter(); + fn compare_tree_to_reference(tree: &Tree, reference: &BTreeMap) { + let resolver = ArcResolver; + let tree_iter = tree.iter(&resolver); let mut reference_iter = reference.iter(); for node in tree_iter { if let Some((key, value)) = reference_iter.next() { + let node = resolver.resolve(node).expect("Node must exist."); assert_eq!(node.key(), key); assert_eq!(node.data(), value); } else { @@ -463,10 +470,72 @@ mod tests { ); } + #[derive(Debug)] + struct FailOnKeyResolver { + fail_on: Key, + } + + impl Resolver> for FailOnKeyResolver { + fn resolve<'a>(&self, id: &'a ArcNodeId) -> Result<&'a Node, OperationalError> { + if >>>::borrow(id).key() == &self.fail_on { + return Err(OperationalError::Resolver); + } + Ok(>>>::borrow(id).as_ref()) + } + + fn resolve_mut<'a>( + &mut self, + id: &'a mut ArcNodeId, + ) -> Result<&'a mut Node, OperationalError> { + if >>>::borrow(id).key() == &self.fail_on { + return Err(OperationalError::Resolver); + } + Ok(Arc::make_mut(>, + >>::borrow_mut(id))) + } + + fn create_id(&self, value: Node) -> ArcNodeId { + Arc::new(value).into() + } + } + + #[test] + fn get_distinguishes_missing_key_from_resolution_error() { + let root = Key::new(&[2]).expect("The key should be valid."); + let left = Key::new(&[1]).expect("The key should be valid."); + let missing = Key::new(&[0]).expect("The key should be valid."); + let no_failure_key = Key::new(&[255]).expect("The key should be valid."); + + let mut tree: Tree = Default::default(); + let mut setup_resolver = ArcResolver; + tree.set(&root, b"root", &mut setup_resolver) + .expect("Setting the root should succeed."); + tree.set(&left, b"left", &mut setup_resolver) + .expect("Setting the left child should succeed."); + + let ok_resolver = FailOnKeyResolver { + fail_on: no_failure_key, + }; + assert!( + matches!(tree.get(&missing, &ok_resolver), Ok(None)), + "Missing key lookup should be distinguishable as Ok(None)." + ); + + let failing_resolver = FailOnKeyResolver { fail_on: left }; + assert!( + matches!( + tree.get(&missing, &failing_resolver), + Err(OperationalError::Resolver) + ), + "Resolver failures should be propagated as Err." + ); + } + proptest! { #[test] fn avl_driver_test(operations in (1usize..500usize).prop_flat_map(operations_strategy)) { - let mut tree: Tree = Default::default(); + let mut tree: Tree = Default::default(); let mut reference: BTreeMap = BTreeMap::new(); let mut resolver = ArcResolver; for operation in operations { @@ -507,7 +576,7 @@ mod tests { #[test] fn test_hash_consistency() { - let mut tree: Tree = Default::default(); + let mut tree: Tree = Default::default(); let mut resolver = ArcResolver; let data = ["42", "6 * 9", "1337", "31337"]; diff --git a/durable-storage/src/errors.rs b/durable-storage/src/errors.rs index 606fcdad7c..aa543bfe98 100644 --- a/durable-storage/src/errors.rs +++ b/durable-storage/src/errors.rs @@ -72,6 +72,9 @@ pub enum OperationalError { #[error("Error while writing to file: {error}")] FileWriteFailed { error: std::io::Error }, + + #[error("Error in resolution of ID.")] + Resolver, } /// Errors that occur because of incorrect usage diff --git a/durable-storage/src/merkle_layer.rs b/durable-storage/src/merkle_layer.rs index a6a1aec03f..de142c505c 100644 --- a/durable-storage/src/merkle_layer.rs +++ b/durable-storage/src/merkle_layer.rs @@ -7,7 +7,9 @@ use std::sync::Arc; use octez_riscv_data::hash::Hash; use octez_riscv_data::serialisation; +use crate::avl::resolver::ArcNodeId; use crate::avl::resolver::ArcResolver; +use crate::avl::resolver::Resolver; use crate::avl::tree::Tree; use crate::commit::CommitId; use crate::errors::OperationalError; @@ -18,7 +20,7 @@ use crate::persistence_layer::PersistenceLayer; /// A layer for transforming data into a Merkelised representation before commitment to the [PersistenceLayer]. #[derive(Clone, Debug)] pub struct MerkleLayer { - tree: Tree, + tree: Tree, persistence: Arc, resolver: ArcResolver, } @@ -54,7 +56,8 @@ impl MerkleLayer { // iteration of the nodes the hashes are // calculated during the encoding of the node // if necessary. - for node in self.tree.iter() { + for node in self.tree.iter(&self.resolver) { + let node = self.resolver.resolve(node)?; let encoded = node.to_encode(&self.resolver)?; let value = serialisation::serialise(encoded) .expect("Serialisation of node data should not fail"); @@ -100,6 +103,8 @@ mod tests { use super::MerkleLayer; use crate::avl::node::Node; use crate::avl::node::Value; + use crate::avl::resolver::ArcNodeId; + use crate::avl::resolver::Resolver; use crate::avl::tree::Tree; use crate::errors::OperationalError; use crate::key::Key; @@ -107,7 +112,7 @@ mod tests { use crate::repo::DirectoryManager; impl MerkleLayer { - fn tree(&self) -> &Tree { + fn tree(&self) -> &Tree { &self.tree } @@ -171,11 +176,11 @@ mod tests { ml.hash().expect("hash operation should succeed.") ); - let old_node1 = Node::new(keys[0].clone(), Bytes::copy_from_slice(&data[0])); - let new_node1 = Node::new(keys[0].clone(), cow_data.as_bytes()); + let old_node1 = Node::::new(keys[0].clone(), Bytes::copy_from_slice(&data[0])); + let new_node1 = Node::::new(keys[0].clone(), cow_data.as_bytes()); - let node2 = Node::new(keys[1].clone(), Bytes::copy_from_slice(&data[1])); - let node3 = Node::new(keys[2].clone(), Bytes::copy_from_slice(&data[2])); + let node2 = Node::::new(keys[1].clone(), Bytes::copy_from_slice(&data[1])); + let node3 = Node::::new(keys[2].clone(), Bytes::copy_from_slice(&data[2])); assert_eq!( &old_node1.data(), @@ -289,7 +294,7 @@ mod tests { ml.hash().expect("hash operation should succeed.") ); - let node = Node::new(key.clone(), data); + let node = Node::::new(key.clone(), data); let get_node = ml .get(&key) .expect("The node should be retrieved successfully") @@ -310,7 +315,7 @@ mod tests { ml.set(&key, &data).expect("setting node should succeed"); let old_hash = ml.hash().expect("hash operation should succeed."); - let node = Node::new(key.clone(), data); + let node = Node::::new(key.clone(), data); let get_node = ml .get(&key) .expect("The node should be retrieved successfully") @@ -326,7 +331,7 @@ mod tests { .expect("The tree should be retrieved successfully."), "AVL isn't in order: {ml:?}" ); - let node = Node::new(key.clone(), data2); + let node = Node::::new(key.clone(), data2); let get_node = ml .get(&key) .expect("The node should be retrieved successfully") @@ -812,7 +817,7 @@ mod tests { let old_hash = ml.hash().expect("hash operation should succeed."); ml.write(&key, 0, &data).expect("write should succeed."); - let node = Node::new(key.clone(), data); + let node = Node::::new(key.clone(), data); let get_node = ml .get(&key) .expect("The node should be retrieved successfully") @@ -835,7 +840,7 @@ mod tests { let old_hash = ml.hash().expect("hash operation should succeed."); let data_len = data.len(); - let node = Node::new(key.clone(), data); + let node = Node::::new(key.clone(), data); let get_node = ml .get(&key) .expect("The node should be retrieved successfully") @@ -851,7 +856,7 @@ mod tests { .expect("The tree should be retrieved successfully."), "AVL isn't in order: {ml:?}" ); - let node = Node::new(key.clone(), Bytes::from("a good value")); + let node = Node::::new(key.clone(), Bytes::from("a good value")); let get_node = ml .get(&key) .expect("The node should be retrieved successfully") @@ -941,14 +946,20 @@ mod tests { .commit() .expect("The commit operation should not fail"); - for node in merkle_layer.tree.iter() { + for node_id in merkle_layer.tree.iter(&merkle_layer.resolver) { + let node = merkle_layer + .resolver + .resolve(node_id) + .expect("Node should be resolvable."); + let encoded = node .to_encode(&merkle_layer.resolver) .expect("Node should be encodable"); let serialised = octez_riscv_data::serialisation::serialise(encoded) .expect("We should be able to serialise the node"); - let node_hash = crate::avl::node::hash(node, &merkle_layer.resolver) + let node_hash = crate::avl::node::hash(node_id, &merkle_layer.resolver) .expect("Resolving the node should succeed."); + let blob = merkle_layer .persistence .blob_get(node_hash)