Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions durable-storage/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ rocksdb.workspace = true
tempfile.workspace = true
thiserror.workspace = true
tokio.workspace = true
trait-set.workspace = true
perfect-derive.workspace = true


[dev-dependencies]
proptest.workspace = true
Expand Down
3 changes: 2 additions & 1 deletion durable-storage/benches/avl_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use bytes::Bytes;
use criterion::Criterion;
use criterion::criterion_group;
use criterion::criterion_main;
use octez_riscv_durable_storage::avl::resolver::ArcNodeId;
use octez_riscv_durable_storage::avl::resolver::ArcResolver;
use octez_riscv_durable_storage::avl::tree::Tree;
use octez_riscv_durable_storage::key::Key;
Expand Down Expand Up @@ -51,7 +52,7 @@ fn bench_avl_tree_operations(c: &mut Criterion) {
let mut resolver = ArcResolver;

// Setting up the tree
let mut tree = Tree::default();
let mut tree = Tree::<ArcNodeId>::default();
for key in &keys[..keys.len() / 2] {
let random_data = generate_random_bytes_in_range(&mut rng, 1..20);
let _ = tree
Expand Down
115 changes: 57 additions & 58 deletions durable-storage/src/avl/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,31 @@
//! Interface for a Merklisable node of an AVL tree

use std::cmp::Ordering;
#[cfg(test)]
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::OnceLock;

use bincode::Encode;
use octez_riscv_data::components::bytes::Bytes;
use octez_riscv_data::hash::Hash;
use octez_riscv_data::mode::Normal;
use perfect_derive::perfect_derive;

use super::tree::Tree;
use crate::avl::resolver::Resolver;
use crate::avl::resolver::NodeResolver;
use crate::errors::OperationalError;
use crate::key::Key;

/// Value stored in a node
pub type Value = Bytes<Normal>;

/// A node that supports rebalancing and Merklisation.
#[derive(Clone, Default, Debug)]
pub struct Node {
#[perfect_derive(Clone, Default, Debug)]
pub struct Node<Id: Clone> {
key: Key,
data: Value,
left: Tree,
right: Tree,
left: Tree<Id>,
right: Tree<Id>,

/// A cache for the hash of this node. This uses `OnceLock` so that updating the cache is a
/// non-mutating operation.
Expand All @@ -52,7 +53,7 @@ struct NodeHashRepresentation<'a, Value> {
balance_factor: i64,
}

impl Node {
impl<Id: Clone> Node<Id> {
/// Create a new leaf [`Node`] from the given key and data.
pub(crate) fn new(key: Key, data: impl Into<Value>) -> Self {
Node {
Expand All @@ -67,7 +68,7 @@ impl Node {
/// [`NodeHashRepresentation`], potentially re-hashing uncached [`Node`]s.
pub(crate) fn to_encode(
&self,
resolver: &impl Resolver<Arc<Node>, Node>,
resolver: &impl NodeResolver<Id>,
) -> Result<impl Encode + '_, OperationalError> {
// Recursively hashes any left child and its children
let left = self
Expand Down Expand Up @@ -111,14 +112,14 @@ impl Node {

#[inline]
/// A mutable reference to the left branch.
pub(super) fn left_mut(&mut self) -> &mut Tree {
pub(super) fn left_mut(&mut self) -> &mut Tree<Id> {
self.invalidate_hash();
&mut self.left
}

#[inline]
/// An immutable reference to the left branch.
pub(super) fn left_ref(&self) -> &Tree {
pub(super) fn left_ref(&self) -> &Tree<Id> {
&self.left
}

Expand All @@ -128,8 +129,8 @@ impl Node {
/// The subtree of the [`Node`] must already have balance factor in the range of -2..=2, else
/// it is an invalid AVL tree.
pub(super) fn rebalance(
node: &mut Arc<Node>,
resolver: &mut impl Resolver<Arc<Node>, Node>,
node: &mut Id,
resolver: &mut impl NodeResolver<Id>,
) -> Result<(), OperationalError> {
let resolved_node = resolver.resolve(node)?;
let balance_factor = resolved_node.balance_factor();
Expand Down Expand Up @@ -172,15 +173,15 @@ impl Node {
/// - The [`Node`] at the root of the new subtree.
/// - `true` if the [`Tree`] has shrunk in size.
pub(super) fn replace_with_successor(
node: &mut Arc<Node>,
resolver: &mut impl Resolver<Arc<Node>, Node>,
) -> Result<(Arc<Node>, bool), OperationalError> {
let node_mut = resolver.resolve_mut(node)?;
current: &mut Id,
resolver: &mut impl NodeResolver<Id>,
) -> Result<(Id, bool), OperationalError> {
let current = resolver.resolve_mut(current)?;

// If the right child has a left child, the successor is the min of the left child's subtree.
let (mut successor, shrank) = if resolver
.resolve(
node_mut
current
.right_ref()
.root()
.expect("A node with a successor must have a right child"),
Expand All @@ -189,48 +190,49 @@ impl Node {
.root()
.is_some()
{
let right = node_mut.right_mut();
let (mut min, _, shrank) = Tree::take_min(right, resolver)?;
let right = current.right_mut();
let (mut min, _, shrank) = right.take_min(resolver)?;
(
min.take()
.expect("A node with a successor must have a right child"),
shrank,
)
// If there is no left child of the right child, the successor is the right child.
} else {
let mut successor = node_mut
let mut successor = current
.right_mut()
.take()
.expect("A node with a successor must have a right child");
let successor_mut = resolver.resolve_mut(&mut successor)?;

// Bump up the (optional) right child of the right child, causing the subtree to shrink.
node_mut.right = successor_mut.right.take().into();
current.right = successor_mut.right.take().into();
(successor, true)
};

let successor_mut = resolver.resolve_mut(&mut successor)?;

successor_mut.balance_factor = node_mut.balance_factor() - if shrank { 1 } else { 0 };
successor_mut.left = std::mem::take(&mut node_mut.left);
successor_mut.right = std::mem::take(&mut node_mut.right);
successor_mut.balance_factor = current.balance_factor() - if shrank { 1 } else { 0 };
successor_mut.left = std::mem::take(&mut current.left);
successor_mut.right = std::mem::take(&mut current.right);

Self::rebalance(&mut successor, resolver)?;

let shrank = node_mut.balance_factor().abs() == 1 && successor.balance_factor == 0;
let successor_node = resolver.resolve(&successor)?;
let shrank = current.balance_factor().abs() == 1 && successor_node.balance_factor == 0;
Ok((successor, shrank))
}

#[inline]
/// A mutable reference to the right branch.
pub(super) fn right_mut(&mut self) -> &mut Tree {
pub(super) fn right_mut(&mut self) -> &mut Tree<Id> {
self.invalidate_hash();
&mut self.right
}

#[inline]
/// An immutable reference to the right branch.
pub(super) fn right_ref(&self) -> &Tree {
pub(super) fn right_ref(&self) -> &Tree<Id> {
&self.right
}

Expand All @@ -242,9 +244,9 @@ impl Node {
/// - The minimum [`Tree`]'s right child, if it hasn't been moved to its new position.
/// - True if this [`Node`]'s subtree has shrunk in size.
pub(super) fn take_min(
node: &mut Arc<Node>,
resolver: &mut impl Resolver<Arc<Node>, Node>,
) -> Result<(Tree, Tree, bool), OperationalError> {
node: &mut Id,
resolver: &mut impl NodeResolver<Id>,
) -> Result<(Tree<Id>, Tree<Id>, bool), OperationalError> {
let node_mut = resolver.resolve_mut(node)?;

let old_node_bf = node_mut.balance_factor();
Expand Down Expand Up @@ -277,7 +279,7 @@ impl Node {
key: &Key,
offset: usize,
data: impl FnOnce(&mut Value),
resolver: &mut impl Resolver<Arc<Node>, Node>,
resolver: &mut impl NodeResolver<Id>,
) -> Result<bool, OperationalError> {
// SAFETY: The default recursion limit in Rust is 128
// see: <https://doc.rust-lang.org/reference/attributes/limits.html#r-attributes.limits.recursion_limit.syntax>
Expand Down Expand Up @@ -341,8 +343,8 @@ impl Node {
/// Assumes this [`Node`]'s balance factor is 2 and the right [`Node`]'s balance factor is +1
/// or 0.
fn rotate_left(
node: &mut Arc<Node>,
resolver: &mut impl Resolver<Arc<Node>, Node>,
node: &mut Id,
resolver: &mut impl NodeResolver<Id>,
) -> Result<(), OperationalError> {
let node_mut = resolver.resolve_mut(node)?;
let mut right = node_mut
Expand Down Expand Up @@ -405,8 +407,8 @@ impl Node {
///
/// Assumes this [`Node`]'s balance factor is -2 and the left [Node]'s balance factor is +1.
fn rotate_left_right(
node: &mut Arc<Node>,
resolver: &mut impl Resolver<Arc<Node>, Node>,
node: &mut Id,
resolver: &mut impl NodeResolver<Id>,
) -> Result<(), OperationalError> {
let node_mut = resolver.resolve_mut(node)?;

Expand All @@ -421,14 +423,14 @@ impl Node {
.take()
.expect("Left's right child must exist for the left rotation of the left node");

let left_right_mut = resolver.resolve_mut(&mut left_right)?;

// From the `rotate_left` derivation, the first rotation does:
// new_A_bf_1 = old_A_bf - 1 + std::cmp::min(-A.right.balance_factor, 0)
// As this function assumes old_A_bf is +1:
// new_A_bf_1 = std::cmp::min(-A.right.balance_factor, 0)
// The second rotation doesn't mutate A's subtree, so the final balance factor is:
left_mut.balance_factor = std::cmp::min(-left_right.balance_factor, 0);

let left_right_mut = resolver.resolve_mut(&mut left_right)?;
left_mut.balance_factor = std::cmp::min(-left_right_mut.balance_factor, 0);

// B's right child is between B and B, it's moved to node's left
node_mut.left = left_right_mut.right.take().into();
Expand Down Expand Up @@ -468,8 +470,8 @@ impl Node {
/// Assumes this [`Node`]'s balance factor is -2 and the left [`Node`]'s balance factor is -1
/// or 0.
fn rotate_right(
node: &mut Arc<Node>,
resolver: &mut impl Resolver<Arc<Node>, Node>,
node: &mut Id,
resolver: &mut impl NodeResolver<Id>,
) -> Result<(), OperationalError> {
let node_mut = resolver.resolve_mut(node).expect("Node must exist.");
let mut left = node_mut
Expand Down Expand Up @@ -532,8 +534,8 @@ impl Node {
///
/// Assumes this [`Node`]'s balance factor is +2 and the left [Node]'s balance factor is -1.
fn rotate_right_left(
node: &mut Arc<Node>,
resolver: &mut impl Resolver<Arc<Node>, Node>,
node: &mut Id,
resolver: &mut impl NodeResolver<Id>,
) -> Result<(), OperationalError> {
let node_mut = resolver.resolve_mut(node)?;
let mut right = node_mut
Expand All @@ -547,14 +549,14 @@ impl Node {
.take()
.expect("Right's left child must exist for the right rotation of the right node");

let right_left_mut = resolver.resolve_mut(&mut right_left)?;

// From the `rotate_right` derivation, the first rotation does:
// new_A_bf_1 = old_A_bf + 1 + std::cmp::max(0, -A.left.balance_factor)
// As this function assumes old_A_bf is -1:
// new_A_bf_1 = std::cmp::max(0, -A.left.balance_factor)
// The second rotation doesn't mutate A's subtree, so the final balance factor is:
right_mut.balance_factor = std::cmp::max(0, -right_left.balance_factor);

let right_left_mut = resolver.resolve_mut(&mut right_left)?;
right_mut.balance_factor = std::cmp::max(0, -right_left_mut.balance_factor);

// B's left child is between node and B, it's moved to node's right
node_mut.right = right_left_mut.left.take().into();
Expand Down Expand Up @@ -583,9 +585,9 @@ impl Node {
///
/// If the hash has been cached, the memo is returned. Otherwise, the hash is calculated and
/// cached.
pub(crate) fn hash<'a>(
node: &'a Arc<Node>,
resolver: &impl Resolver<Arc<Node>, Node>,
pub(crate) fn hash<'a, Id: Clone>(
node: &'a Id,
resolver: &impl NodeResolver<Id>,
) -> Result<&'a Hash, OperationalError> {
let resolved = resolver.resolve(node)?;

Expand All @@ -601,7 +603,7 @@ pub(crate) fn hash<'a>(
}

#[cfg(test)]
impl Node {
impl<Id: Clone + Debug> Node<Id> {
#[inline]
/// The data stored in the [`Node`].
pub(crate) fn data(&self) -> &Value {
Expand All @@ -610,9 +612,9 @@ impl Node {

/// The data stored in a [`Node`] within the subtree of this [`Node`] with a given [`Key`] .
pub(super) fn get<'a>(
mut node: &'a Arc<Node>,
mut node: &'a Id,
key: &Key,
resolver: &impl Resolver<Arc<Node>, Node>,
resolver: &impl NodeResolver<Id>,
) -> Result<Option<&'a Value>, OperationalError> {
loop {
let resolved_node = resolver.resolve(node)?;
Expand All @@ -637,7 +639,7 @@ impl Node {
/// Returns true if the balance factors stored in the [`Node`]'s subtree are correct.
pub(super) fn has_correct_balance_factors(
&self,
resolver: &impl Resolver<Arc<Node>, Node>,
resolver: &impl NodeResolver<Id>,
) -> Result<bool, OperationalError> {
let left_height = self.left_ref().height(resolver)?;
let right_height = self.right_ref().height(resolver)?;
Expand All @@ -655,10 +657,7 @@ impl Node {
}

/// Returns the height of this [`Node`]'s subtree.
pub(super) fn height(
&self,
resolver: &impl Resolver<Arc<Node>, Node>,
) -> Result<u32, OperationalError> {
pub(super) fn height(&self, resolver: &impl NodeResolver<Id>) -> Result<u32, OperationalError> {
let left_height = self.left_ref().height(resolver)?;
let right_height = self.right_ref().height(resolver)?;
Ok(1 + std::cmp::max(left_height, right_height))
Expand All @@ -667,7 +666,7 @@ impl Node {
/// Returns true if this [`Node`]'s subtree is balanced.
pub(super) fn is_balanced(
&self,
resolver: &impl Resolver<Arc<Node>, Node>,
resolver: &impl NodeResolver<Id>,
) -> Result<bool, OperationalError> {
let balance_factor = self.balance_factor();
if balance_factor.abs() > 1 {
Expand All @@ -684,7 +683,7 @@ impl Node {
&self,
min: &Key,
max: &Key,
resolver: &impl Resolver<Arc<Node>, Node>,
resolver: &impl NodeResolver<Id>,
) -> Result<bool, OperationalError> {
if self.key() < min || self.key() > max {
return Ok(false);
Expand Down
Loading
Loading