Skip to content

feat: Merkle tree primitives as outlined in spec #135

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions src/backends/plonky2/basetypes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,46 @@ pub fn hash_fields(input: &[F]) -> Hash {
Hash(PoseidonHash::hash_no_pad(&input).elements)
}

/// Hash function for key-value pairs. Different branch pair hashes to
/// mitigate fake proofs.
pub fn kv_hash(key: &Value, value: Option<Value>) -> Hash {
value
.map(|v| {
Hash(
PoseidonHash::hash_no_pad(
&[key.0.to_vec(), v.0.to_vec(), vec![GoldilocksField(1)]].concat(),
)
.elements,
)
})
.unwrap_or(Hash([GoldilocksField(0); 4]))
}

// NOTE 1: think if maybe the length of the returned vector can be <256
// (8*bytes.len()), so that we can do fewer iterations. For example, if the
// tree.max_depth is set to 20, we just need 20 iterations of the loop, not 256.
// NOTE 2: which approach do we take with keys that are longer than the
// max-depth? ie, what happens when two keys share the same path for more bits
// than the max_depth?
/// returns the path (bit decomposition) of the given key
pub fn keypath(max_depth: usize, k: Value) -> Result<Vec<bool>> {
let bytes = k.to_bytes();
if max_depth > 8 * bytes.len() {
// note that our current keys are of Value type, which are 4 Goldilocks
// field elements, ie ~256 bits, therefore the max_depth can not be
// bigger than 256.
Err(anyhow!(
"key too short (key length: {}) for the max_depth: {}",
8 * bytes.len(),
max_depth
))
} else {
Ok((0..max_depth)
.map(|n| bytes[n / 8] & (1 << (n % 8)) != 0)
.collect())
}
}

impl From<Value> for Hash {
fn from(v: Value) -> Self {
Hash(v.0)
Expand Down
2 changes: 1 addition & 1 deletion src/backends/plonky2/mock_main/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,7 @@ impl Pod for MockMainPod {
self.operations[i]
.deref(&self.statements[..input_statement_offset + i])
.unwrap()
.check(&self.params, &s.clone().try_into().unwrap())
.check_and_print(&self.params, &s.clone().try_into().unwrap())
})
.collect::<Result<Vec<_>>>()
.unwrap();
Expand Down
13 changes: 13 additions & 0 deletions src/backends/plonky2/mock_main/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,19 @@ impl TryFrom<Statement> for middleware::Statement {
(NP::NotContains, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), None), 2) => {
S::NotContains(ak1, ak2)
}
(NP::Branches, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => {
S::Branches(ak1, ak2, ak3)
}
(NP::Leaf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => {
S::Leaf(ak1, ak2, ak3)
}
(NP::IsNullTree, (Some(SA::Key(ak)), None, None), 1) => S::IsNullTree(ak),
(NP::GoesLeft, (Some(SA::Key(ak)), Some(SA::Literal(depth)), None), 2) => {
S::GoesLeft(ak, depth)
}
(NP::GoesRight, (Some(SA::Key(ak)), Some(SA::Literal(depth)), None), 2) => {
S::GoesRight(ak, depth)
}
(NP::SumOf, (Some(SA::Key(ak1)), Some(SA::Key(ak2)), Some(SA::Key(ak3))), 3) => {
S::SumOf(ak1, ak2, ak3)
}
Expand Down
68 changes: 30 additions & 38 deletions src/backends/plonky2/primitives/merkletree.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
//! Module that implements the MerkleTree specified at
//! https://0xparc.github.io/pod2/merkletree.html .
use anyhow::{anyhow, Result};
use plonky2::field::goldilocks_field::GoldilocksField;
use std::collections::HashMap;
use std::fmt;
use std::iter::IntoIterator;

use crate::backends::counter;
use crate::backends::plonky2::basetypes::{hash_fields, Hash, Value, F, NULL};
use crate::middleware::{keypath, kv_hash};

/// Implements the MerkleTree specified at
/// https://0xparc.github.io/pod2/merkletree.html
#[derive(Clone, Debug)]
pub struct MerkleTree {
max_depth: usize,
root: Node,
pub max_depth: usize,
pub root: Node,
}

impl MerkleTree {
Expand Down Expand Up @@ -174,14 +174,6 @@ impl MerkleTree {
}
}

/// Hash function for key-value pairs. Different branch pair hashes to
/// mitigate fake proofs.
pub fn kv_hash(key: &Value, value: Option<Value>) -> Hash {
value
.map(|v| hash_fields(&[key.0.to_vec(), v.0.to_vec(), vec![GoldilocksField(1)]].concat()))
.unwrap_or(Hash([GoldilocksField(0); 4]))
}

impl<'a> IntoIterator for &'a MerkleTree {
type Item = (&'a Value, &'a Value);
type IntoIter = Iter<'a>;
Expand Down Expand Up @@ -256,7 +248,7 @@ impl MerkleProof {
}

#[derive(Clone, Debug)]
enum Node {
pub enum Node {
None,
Leaf(Leaf),
Intermediate(Intermediate),
Expand Down Expand Up @@ -293,7 +285,7 @@ impl fmt::Display for Node {
}

impl Node {
fn is_empty(&self) -> bool {
pub fn is_empty(&self) -> bool {
match self {
Self::None => true,
Self::Leaf(_l) => false,
Expand All @@ -307,14 +299,38 @@ impl Node {
Self::Intermediate(n) => n.compute_hash(),
}
}
fn hash(&self) -> Hash {
pub fn hash(&self) -> Hash {
match self {
Self::None => NULL,
Self::Leaf(l) => l.hash(),
Self::Intermediate(n) => n.hash(),
}
}

pub fn left(&self) -> Option<&Box<Node>> {
match self {
Self::None => None,
Self::Leaf(_l) => None,
Self::Intermediate(Intermediate {
hash: _h,
left: l,
right: _r
}) => Some(l),
}
}

pub fn right(&self) -> Option<&Box<Node>> {
match self {
Self::None => None,
Self::Leaf(_l) => None,
Self::Intermediate(Intermediate {
hash: _h,
left: _l,
right: r
}) => Some(r),
}
}

/// Goes down from the current node until it encounters a terminal node,
/// viz. a leaf or empty node, or until it reaches the maximum depth. The
/// `siblings` parameter is used to store the siblings while going down to
Expand Down Expand Up @@ -513,30 +529,6 @@ impl Leaf {
}
}

// NOTE 1: think if maybe the length of the returned vector can be <256
// (8*bytes.len()), so that we can do fewer iterations. For example, if the
// tree.max_depth is set to 20, we just need 20 iterations of the loop, not 256.
// NOTE 2: which approach do we take with keys that are longer than the
// max-depth? ie, what happens when two keys share the same path for more bits
// than the max_depth?
/// returns the path of the given key
fn keypath(max_depth: usize, k: Value) -> Result<Vec<bool>> {
let bytes = k.to_bytes();
if max_depth > 8 * bytes.len() {
// note that our current keys are of Value type, which are 4 Goldilocks
// field elements, ie ~256 bits, therefore the max_depth can not be
// bigger than 256.
return Err(anyhow!(
"key to short (key length: {}) for the max_depth: {}",
8 * bytes.len(),
max_depth
));
}
Ok((0..max_depth)
.map(|n| bytes[n / 8] & (1 << (n % 8)) != 0)
.collect())
}

pub struct Iter<'a> {
state: Vec<&'a Node>,
}
Expand Down
2 changes: 1 addition & 1 deletion src/constants.rs
Original file line number Diff line number Diff line change
@@ -1 +1 @@
pub const MAX_DEPTH: usize = 32;
pub const MAX_DEPTH: usize = 256;
86 changes: 85 additions & 1 deletion src/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@ use std::collections::HashMap;
use std::convert::From;
use std::{fmt, hash as h};

use crate::backends::plonky2::primitives::merkletree::MerkleProof;
use crate::middleware::{
self,
containers::{Array, Dictionary, Set},
hash_str, Hash, MainPodInputs, NativeOperation, NativePredicate, Params, PodId, PodProver,
PodSigner, SELF,
};
use crate::middleware::{OperationType, Predicate, KEY_SIGNER, KEY_TYPE};
use crate::middleware::{kv_hash, OperationType, Predicate, KEY_SIGNER, KEY_TYPE};
use crate::op;

mod custom;
mod operation;
Expand Down Expand Up @@ -437,6 +439,11 @@ impl MainPodBuilder {
},
ContainsFromEntries => self.op_args_entries(public, args)?,
NotContainsFromEntries => self.op_args_entries(public, args)?,
BranchesFromEntries => self.op_args_entries(public, args)?,
LeafFromEntries => self.op_args_entries(public, args)?,
IsNullTree => self.op_args_entries(public, args)?,
GoesLeft => self.op_args_entries(public, args)?,
GoesRight => self.op_args_entries(public, args)?,
SumOf => match (args[0].clone(), args[1].clone(), args[2].clone()) {
(
OperationArg::Statement(Statement(
Expand Down Expand Up @@ -920,6 +927,12 @@ pub mod build_utils {
(not_contains, $($arg:expr),+) => { crate::frontend::Operation(
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::NotContainsFromEntries),
crate::op_args!($($arg),*)) };
(branches, $($arg:expr),+) => { crate::frontend::Operation(
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::BranchesFromEntries),
crate::op_args!($($arg),*)) };
(leaf, $($arg:expr),+) => { crate::frontend::Operation(
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::LeafFromEntries),
crate::op_args!($($arg),*)) };
(sum_of, $($arg:expr),+) => { crate::frontend::Operation(
crate::middleware::OperationType::Native(crate::middleware::NativeOperation::SumOf),
crate::op_args!($($arg),*)) };
Expand All @@ -940,6 +953,7 @@ pub mod tests {
use super::*;
use crate::backends::plonky2::mock_main::MockProver;
use crate::backends::plonky2::mock_signed::MockSigner;
use crate::backends::plonky2::primitives::merkletree::MerkleTree;
use crate::examples::{
eth_dos_pod_builder, eth_friend_signed_pod_builder, great_boy_pod_full_flow,
tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_sign_pod_builders,
Expand Down Expand Up @@ -1152,4 +1166,74 @@ pub mod tests {
println!("{}", builder);
println!("{}", false_pod);
}

#[test]
fn test_merkle_proofs() -> Result<()> {
let mut kvs = HashMap::new();
for i in 0..8 {
if i == 1 {
continue;
}
kvs.insert(middleware::Value::from(i), middleware::Value::from(1000 + i));
}
let key = middleware::Value::from(13);
let value = middleware::Value::from(1013);
kvs.insert(key, value);

let tree = MerkleTree::new(32, &kvs)?;
// when printing the tree, it should print the same tree as in
// https://0xparc.github.io/pod2/merkletree.html#example-2
println!("{}", tree);

println!("{}", tree.root());

let root: Hash = tree.root();
let left: Hash = (*tree.root.left().unwrap().clone()).hash();
let right: Hash = (*tree.root.right().unwrap().clone()).hash();

let params = Params::default();
let mut signed_builder = SignedPodBuilder::new(&params);

let mut builder = MainPodBuilder::new(&params);
let introduce_root_op = Operation(
OperationType::Native(NativeOperation::NewEntry),
vec![
OperationArg::Entry("root".into(), Value::Raw(middleware::Value::from(root))),
]
);
let introduce_left_op = Operation(
OperationType::Native(NativeOperation::NewEntry),
vec![
OperationArg::Entry("left".into(), Value::Raw(middleware::Value::from(left))),
]
);
let introduce_right_op = Operation(
OperationType::Native(NativeOperation::NewEntry),
vec![
OperationArg::Entry("right".into(), Value::Raw(middleware::Value::from(right))),
]
);
let st1 = builder.op(false, introduce_root_op).unwrap();
let st2 = builder.op(false, introduce_left_op).unwrap();
let st3 = builder.op(false, introduce_right_op).unwrap();

// verify Branches statement
let branches_op = Operation(
OperationType::Native(NativeOperation::BranchesFromEntries),
vec![
OperationArg::Statement(st1),
OperationArg::Statement(st2),
OperationArg::Statement(st3),
]
);

let _branches_st = builder.op(true, branches_op).unwrap();

let mut prover = MockProver {};
let pod = builder.prove(&mut prover, &params).unwrap();
print!("{}", pod);
assert_eq!(pod.pod.verify(), true);

Ok(())
}
}
Loading