Skip to content

Commit 29f611e

Browse files
Amxxernestognw
andauthored
Support custom node hash in SimpleMerkleTree (#39)
Co-authored-by: ernestognw <[email protected]>
1 parent 6ab2cfb commit 29f611e

9 files changed

+399
-99
lines changed

src/core.ts

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
1-
import { keccak256 } from '@ethersproject/keccak256';
2-
import { BytesLike, HexString, toHex, toBytes, concat, compare } from './bytes';
1+
import { BytesLike, HexString, toHex, toBytes, compare } from './bytes';
2+
import { NodeHash, standardNodeHash } from './hashes';
33
import { invariant, throwError, validateArgument } from './utils/errors';
44

5-
const hashPair = (a: BytesLike, b: BytesLike): HexString => keccak256(concat([a, b].sort(compare)));
6-
75
const leftChildIndex = (i: number) => 2 * i + 1;
86
const rightChildIndex = (i: number) => 2 * i + 2;
97
const parentIndex = (i: number) => (i > 0 ? Math.floor((i - 1) / 2) : throwError('Root has no parent'));
@@ -18,7 +16,7 @@ const checkLeafNode = (tree: unknown[], i: number) => void (isLeafNode(tree, i)
1816
const checkValidMerkleNode = (node: BytesLike) =>
1917
void (isValidMerkleNode(node) || throwError('Merkle tree nodes must be Uint8Array of length 32'));
2018

21-
export function makeMerkleTree(leaves: BytesLike[]): HexString[] {
19+
export function makeMerkleTree(leaves: BytesLike[], nodeHash: NodeHash = standardNodeHash): HexString[] {
2220
leaves.forEach(checkValidMerkleNode);
2321

2422
validateArgument(leaves.length !== 0, 'Expected non-zero number of leaves');
@@ -29,7 +27,7 @@ export function makeMerkleTree(leaves: BytesLike[]): HexString[] {
2927
tree[tree.length - 1 - i] = toHex(leaf);
3028
}
3129
for (let i = tree.length - 1 - leaves.length; i >= 0; i--) {
32-
tree[i] = hashPair(tree[leftChildIndex(i)]!, tree[rightChildIndex(i)]!);
30+
tree[i] = nodeHash(tree[leftChildIndex(i)]!, tree[rightChildIndex(i)]!);
3331
}
3432

3533
return tree;
@@ -46,11 +44,11 @@ export function getProof(tree: BytesLike[], index: number): HexString[] {
4644
return proof;
4745
}
4846

49-
export function processProof(leaf: BytesLike, proof: BytesLike[]): HexString {
47+
export function processProof(leaf: BytesLike, proof: BytesLike[], nodeHash: NodeHash = standardNodeHash): HexString {
5048
checkValidMerkleNode(leaf);
5149
proof.forEach(checkValidMerkleNode);
5250

53-
return toHex(proof.reduce(hashPair, leaf));
51+
return toHex(proof.reduce(nodeHash, leaf));
5452
}
5553

5654
export interface MultiProof<T, L = T> {
@@ -68,7 +66,7 @@ export function getMultiProof(tree: BytesLike[], indices: number[]): MultiProof<
6866
'Cannot prove duplicated index',
6967
);
7068

71-
const stack = indices.concat(); // copy
69+
const stack = Array.from(indices); // copy
7270
const proof = [];
7371
const proofFlags = [];
7472

@@ -98,7 +96,7 @@ export function getMultiProof(tree: BytesLike[], indices: number[]): MultiProof<
9896
};
9997
}
10098

101-
export function processMultiProof(multiproof: MultiProof<BytesLike>): HexString {
99+
export function processMultiProof(multiproof: MultiProof<BytesLike>, nodeHash: NodeHash = standardNodeHash): HexString {
102100
multiproof.leaves.forEach(checkValidMerkleNode);
103101
multiproof.proof.forEach(checkValidMerkleNode);
104102

@@ -111,22 +109,22 @@ export function processMultiProof(multiproof: MultiProof<BytesLike>): HexString
111109
'Provided leaves and multiproof are not compatible',
112110
);
113111

114-
const stack = multiproof.leaves.concat(); // copy
115-
const proof = multiproof.proof.concat(); // copy
112+
const stack = Array.from(multiproof.leaves); // copy
113+
const proof = Array.from(multiproof.proof); // copy
116114

117115
for (const flag of multiproof.proofFlags) {
118116
const a = stack.shift();
119117
const b = flag ? stack.shift() : proof.shift();
120118
invariant(a !== undefined && b !== undefined);
121-
stack.push(hashPair(a, b));
119+
stack.push(nodeHash(a, b));
122120
}
123121

124122
invariant(stack.length + proof.length === 1);
125123

126124
return toHex(stack.pop() ?? proof.shift()!);
127125
}
128126

129-
export function isValidMerkleTree(tree: BytesLike[]): boolean {
127+
export function isValidMerkleTree(tree: BytesLike[], nodeHash: NodeHash = standardNodeHash): boolean {
130128
for (const [i, node] of tree.entries()) {
131129
if (!isValidMerkleNode(node)) {
132130
return false;
@@ -139,7 +137,7 @@ export function isValidMerkleTree(tree: BytesLike[]): boolean {
139137
if (l < tree.length) {
140138
return false;
141139
}
142-
} else if (compare(node, hashPair(tree[l]!, tree[r]!))) {
140+
} else if (compare(node, nodeHash(tree[l]!, tree[r]!))) {
143141
return false;
144142
}
145143
}

src/hashes.ts

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import { defaultAbiCoder } from '@ethersproject/abi';
2+
import { keccak256 } from '@ethersproject/keccak256';
3+
import { BytesLike, HexString, concat, compare } from './bytes';
4+
5+
export type LeafHash<T> = (leaf: T) => HexString;
6+
export type NodeHash = (left: BytesLike, right: BytesLike) => HexString;
7+
8+
export function standardLeafHash<T extends any[]>(types: string[], value: T): HexString {
9+
return keccak256(keccak256(defaultAbiCoder.encode(types, value)));
10+
}
11+
12+
export function standardNodeHash(a: BytesLike, b: BytesLike): HexString {
13+
return keccak256(concat([a, b].sort(compare)));
14+
}

src/merkletree.ts

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import {
1212
} from './core';
1313

1414
import { MerkleTreeOptions, defaultOptions } from './options';
15+
import { LeafHash, NodeHash } from './hashes';
1516
import { validateArgument, invariant } from './utils/errors';
1617

1718
export interface MerkleTreeData<T> {
@@ -40,7 +41,8 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
4041
protected constructor(
4142
protected readonly tree: HexString[],
4243
protected readonly values: MerkleTreeData<T>['values'],
43-
public readonly leafHash: MerkleTree<T>['leafHash'],
44+
public readonly leafHash: LeafHash<T>,
45+
protected readonly nodeHash?: NodeHash,
4446
) {
4547
validateArgument(
4648
values.every(({ value }) => typeof value != 'number'),
@@ -52,7 +54,8 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
5254
protected static prepare<T>(
5355
values: T[],
5456
options: MerkleTreeOptions = {},
55-
leafHash: MerkleTree<T>['leafHash'],
57+
leafHash: LeafHash<T>,
58+
nodeHash?: NodeHash,
5659
): [tree: HexString[], indexedValues: MerkleTreeData<T>['values']] {
5760
const sortLeaves = options.sortLeaves ?? defaultOptions.sortLeaves;
5861
const hashedValues = values.map((value, valueIndex) => ({
@@ -65,7 +68,10 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
6568
hashedValues.sort((a, b) => compare(a.hash, b.hash));
6669
}
6770

68-
const tree = makeMerkleTree(hashedValues.map(v => v.hash));
71+
const tree = makeMerkleTree(
72+
hashedValues.map(v => v.hash),
73+
nodeHash,
74+
);
6975

7076
const indexedValues = values.map(value => ({ value, treeIndex: 0 }));
7177
for (const [leafIndex, { valueIndex }] of hashedValues.entries()) {
@@ -93,7 +99,7 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
9399

94100
validate(): void {
95101
this.values.forEach((_, i) => this._validateValueAt(i));
96-
invariant(isValidMerkleTree(this.tree), 'Merkle tree is invalid');
102+
invariant(isValidMerkleTree(this.tree, this.nodeHash), 'Merkle tree is invalid');
97103
}
98104

99105
leafLookup(leaf: T): number {
@@ -171,10 +177,10 @@ export abstract class MerkleTreeImpl<T> implements MerkleTree<T> {
171177
}
172178

173179
private _verify(leafHash: BytesLike, proof: BytesLike[]): boolean {
174-
return this.root === processProof(leafHash, proof);
180+
return this.root === processProof(leafHash, proof, this.nodeHash);
175181
}
176182

177183
private _verifyMultiProof(multiproof: MultiProof<BytesLike>): boolean {
178-
return this.root === processMultiProof(multiproof);
184+
return this.root === processMultiProof(multiproof, this.nodeHash);
179185
}
180186
}

src/simple.test.ts

Lines changed: 91 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,36 @@
11
import { test, testProp, fc } from '@fast-check/ava';
22
import { HashZero as zero } from '@ethersproject/constants';
3+
import { keccak256 } from '@ethersproject/keccak256';
34
import { SimpleMerkleTree } from './simple';
5+
import { BytesLike, HexString, concat, compare } from './bytes';
6+
7+
const reverseNodeHash = (a: BytesLike, b: BytesLike): HexString => keccak256(concat([a, b].sort(compare).reverse()));
8+
const otherNodeHash = (a: BytesLike, b: BytesLike): HexString => keccak256(reverseNodeHash(a, b)); // double hash
9+
410
import { toHex } from './bytes';
511
import { InvalidArgumentError, InvariantError } from './utils/errors';
612

713
const leaf = fc.uint8Array({ minLength: 32, maxLength: 32 }).map(toHex);
814
const leaves = fc.array(leaf, { minLength: 1 });
9-
const options = fc.record({ sortLeaves: fc.oneof(fc.constant(undefined), fc.boolean()) });
15+
const options = fc.record({
16+
sortLeaves: fc.oneof(fc.constant(undefined), fc.boolean()),
17+
nodeHash: fc.oneof(fc.constant(undefined), fc.constant(reverseNodeHash)),
18+
});
1019

11-
const tree = fc.tuple(leaves, options).map(([leaves, options]) => SimpleMerkleTree.of(leaves, options));
20+
const tree = fc
21+
.tuple(leaves, options)
22+
.chain(([leaves, options]) => fc.tuple(fc.constant(SimpleMerkleTree.of(leaves, options)), fc.constant(options)));
1223
const treeAndLeaf = fc.tuple(leaves, options).chain(([leaves, options]) =>
1324
fc.tuple(
1425
fc.constant(SimpleMerkleTree.of(leaves, options)),
26+
fc.constant(options),
1527
fc.nat({ max: leaves.length - 1 }).map(index => ({ value: leaves[index]!, index })),
1628
),
1729
);
1830
const treeAndLeaves = fc.tuple(leaves, options).chain(([leaves, options]) =>
1931
fc.tuple(
2032
fc.constant(SimpleMerkleTree.of(leaves, options)),
33+
fc.constant(options),
2134
fc
2235
.uniqueArray(fc.nat({ max: leaves.length - 1 }))
2336
.map(indices => indices.map(index => ({ value: leaves[index]!, index }))),
@@ -26,48 +39,64 @@ const treeAndLeaves = fc.tuple(leaves, options).chain(([leaves, options]) =>
2639

2740
fc.configureGlobal({ numRuns: process.env.CI ? 10000 : 100 });
2841

29-
testProp('generates a valid tree', [tree], (t, tree) => {
42+
testProp('generates a valid tree', [tree], (t, [tree]) => {
3043
t.notThrows(() => tree.validate());
3144
});
3245

33-
testProp('generates valid single proofs for all leaves', [treeAndLeaf], (t, [tree, { value: leaf, index }]) => {
34-
const proof1 = tree.getProof(index);
35-
const proof2 = tree.getProof(leaf);
36-
37-
t.deepEqual(proof1, proof2);
38-
t.true(tree.verify(index, proof1));
39-
t.true(tree.verify(leaf, proof1));
40-
t.true(SimpleMerkleTree.verify(tree.root, leaf, proof1));
41-
});
46+
testProp(
47+
'generates valid single proofs for all leaves',
48+
[treeAndLeaf],
49+
(t, [tree, options, { value: leaf, index }]) => {
50+
const proof1 = tree.getProof(index);
51+
const proof2 = tree.getProof(leaf);
52+
53+
t.deepEqual(proof1, proof2);
54+
t.true(tree.verify(index, proof1));
55+
t.true(tree.verify(leaf, proof1));
56+
t.true(SimpleMerkleTree.verify(tree.root, leaf, proof1, options.nodeHash));
57+
},
58+
);
4259

43-
testProp('rejects invalid proofs', [treeAndLeaf, tree], (t, [tree, { value: leaf }], otherTree) => {
44-
const proof = tree.getProof(leaf);
45-
t.false(otherTree.verify(leaf, proof));
46-
t.false(SimpleMerkleTree.verify(otherTree.root, leaf, proof));
47-
});
60+
testProp(
61+
'rejects invalid proofs',
62+
[treeAndLeaf, tree],
63+
(t, [tree, options, { value: leaf }], [otherTree, otherOptions]) => {
64+
const proof = tree.getProof(leaf);
65+
t.false(otherTree.verify(leaf, proof));
66+
t.false(SimpleMerkleTree.verify(otherTree.root, leaf, proof, options.nodeHash));
67+
t.false(SimpleMerkleTree.verify(otherTree.root, leaf, proof, otherOptions.nodeHash));
68+
},
69+
);
4870

49-
testProp('generates valid multiproofs', [treeAndLeaves], (t, [tree, indices]) => {
71+
testProp('generates valid multiproofs', [treeAndLeaves], (t, [tree, options, indices]) => {
5072
const proof1 = tree.getMultiProof(indices.map(e => e.index));
5173
const proof2 = tree.getMultiProof(indices.map(e => e.value));
5274

5375
t.deepEqual(proof1, proof2);
5476
t.true(tree.verifyMultiProof(proof1));
55-
t.true(SimpleMerkleTree.verifyMultiProof(tree.root, proof1));
77+
t.true(SimpleMerkleTree.verifyMultiProof(tree.root, proof1, options.nodeHash));
5678
});
5779

58-
testProp('rejects invalid multiproofs', [treeAndLeaves, tree], (t, [tree, indices], otherTree) => {
59-
const multiProof = tree.getMultiProof(indices.map(e => e.index));
60-
61-
t.false(otherTree.verifyMultiProof(multiProof));
62-
t.false(SimpleMerkleTree.verifyMultiProof(otherTree.root, multiProof));
63-
});
80+
testProp(
81+
'rejects invalid multiproofs',
82+
[treeAndLeaves, tree],
83+
(t, [tree, options, indices], [otherTree, otherOptions]) => {
84+
const multiProof = tree.getMultiProof(indices.map(e => e.index));
85+
86+
t.false(otherTree.verifyMultiProof(multiProof));
87+
t.false(SimpleMerkleTree.verifyMultiProof(otherTree.root, multiProof, options.nodeHash));
88+
t.false(SimpleMerkleTree.verifyMultiProof(otherTree.root, multiProof, otherOptions.nodeHash));
89+
},
90+
);
6491

6592
testProp(
6693
'renders tree representation',
6794
[leaves],
6895
(t, leaves) => {
6996
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true }).render());
7097
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false }).render());
98+
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true, nodeHash: reverseNodeHash }).render());
99+
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false, nodeHash: reverseNodeHash }).render());
71100
},
72101
{ numRuns: 1, seed: 0 },
73102
);
@@ -78,24 +107,34 @@ testProp(
78107
(t, leaves) => {
79108
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true }).dump());
80109
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false }).dump());
110+
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: true, nodeHash: reverseNodeHash }).dump());
111+
t.snapshot(SimpleMerkleTree.of(leaves, { sortLeaves: false, nodeHash: reverseNodeHash }).dump());
81112
},
82113
{ numRuns: 1, seed: 0 },
83114
);
84115

85-
testProp('dump and load', [tree], (t, tree) => {
86-
const recoveredTree = SimpleMerkleTree.load(tree.dump());
87-
recoveredTree.validate();
116+
testProp('dump and load', [tree], (t, [tree, options]) => {
117+
const dump = tree.dump();
118+
const recoveredTree = SimpleMerkleTree.load(dump, options.nodeHash);
119+
recoveredTree.validate(); // already done in load
88120

121+
t.is(dump.hash, options.nodeHash ? 'custom' : undefined);
89122
t.is(tree.root, recoveredTree.root);
90123
t.is(tree.render(), recoveredTree.render());
91124
t.deepEqual(tree.entries(), recoveredTree.entries());
92125
t.deepEqual(tree.dump(), recoveredTree.dump());
93126
});
94127

95-
testProp('reject out of bounds value index', [tree], (t, tree) => {
128+
testProp('reject out of bounds value index', [tree], (t, [tree]) => {
96129
t.throws(() => tree.getProof(-1), new InvalidArgumentError('Index out of bounds'));
97130
});
98131

132+
// We need at least 2 leaves for internal node hashing to come into play
133+
testProp('reject loading dump with wrong node hash', [fc.array(leaf, { minLength: 2 })], (t, leaves) => {
134+
const dump = SimpleMerkleTree.of(leaves, { nodeHash: reverseNodeHash }).dump();
135+
t.throws(() => SimpleMerkleTree.load(dump, otherNodeHash), new InvariantError('Merkle tree is invalid'));
136+
});
137+
99138
test('reject invalid leaf size', t => {
100139
const invalidLeaf = '0x000000000000000000000000000000000000000000000000000000000000000000';
101140
t.throws(() => SimpleMerkleTree.of([invalidLeaf]), {
@@ -116,22 +155,28 @@ test('reject unrecognized tree dump', t => {
116155
});
117156

118157
test('reject malformed tree dump', t => {
119-
const loadedTree1 = SimpleMerkleTree.load({
120-
format: 'simple-v1',
121-
tree: [zero],
122-
values: [
123-
{
124-
value: '0x0000000000000000000000000000000000000000000000000000000000000001',
125-
treeIndex: 0,
126-
},
127-
],
128-
});
129-
t.throws(() => loadedTree1.getProof(0), new InvariantError('Merkle tree does not contain the expected value'));
158+
t.throws(
159+
() =>
160+
SimpleMerkleTree.load({
161+
format: 'simple-v1',
162+
tree: [zero],
163+
values: [
164+
{
165+
value: '0x0000000000000000000000000000000000000000000000000000000000000001',
166+
treeIndex: 0,
167+
},
168+
],
169+
}),
170+
new InvariantError('Merkle tree does not contain the expected value'),
171+
);
130172

131-
const loadedTree2 = SimpleMerkleTree.load({
132-
format: 'simple-v1',
133-
tree: [zero, zero, zero],
134-
values: [{ value: zero, treeIndex: 2 }],
135-
});
136-
t.throws(() => loadedTree2.getProof(0), new InvariantError('Unable to prove value'));
173+
t.throws(
174+
() =>
175+
SimpleMerkleTree.load({
176+
format: 'simple-v1',
177+
tree: [zero, zero, zero],
178+
values: [{ value: zero, treeIndex: 2 }],
179+
}),
180+
new InvariantError('Merkle tree is invalid'),
181+
);
137182
});

0 commit comments

Comments
 (0)