Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Unsafe IMT optimization #155

Draft
wants to merge 7 commits into
base: IMT-optimization
Choose a base branch
from
126 changes: 67 additions & 59 deletions contracts/data/IncrementalMerkleTree.sol
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ library IncrementalMerkleTree {
using IncrementalMerkleTree for Tree;

struct Tree {
bytes32[][] nodes;
bytes32[][] __nodes;
}

/**
Expand All @@ -15,10 +15,8 @@ library IncrementalMerkleTree {
* @return treeSize size of tree
*/
function size(Tree storage t) internal view returns (uint256 treeSize) {
bytes32[][] storage nodes = t.nodes;

assembly {
mstore(0x00, nodes.slot)
mstore(0x00, t.slot)
treeSize := sload(keccak256(0x00, 0x20))
}
}
Expand All @@ -30,10 +28,8 @@ library IncrementalMerkleTree {
* @return treeHeight one-indexed height of tree
*/
function height(Tree storage t) internal view returns (uint256 treeHeight) {
bytes32[][] storage nodes = t.nodes;

assembly {
treeHeight := sload(nodes.slot)
treeHeight := sload(t.slot)
}
}

Expand All @@ -43,13 +39,10 @@ library IncrementalMerkleTree {
* @return hash root hash
*/
function root(Tree storage t) internal view returns (bytes32 hash) {
bytes32[][] storage nodes = t.nodes;

uint256 treeHeight = t.height();

if (treeHeight > 0) {
assembly {
mstore(0x00, nodes.slot)
assembly {
let treeHeight := sload(t.slot)
if gt(treeHeight, 0) {
mstore(0x00, t.slot)
mstore(0x00, add(keccak256(0x00, 0x20), sub(treeHeight, 1)))
hash := sload(keccak256(0x00, 0x20))
}
Expand All @@ -60,7 +53,15 @@ library IncrementalMerkleTree {
Tree storage t,
uint256 index
) internal view returns (bytes32 hash) {
hash = t.nodes[0][index];
if (index >= t.size()) {
new bytes32[](0)[1];
}

assembly {
mstore(0x00, t.slot)
mstore(0x00, keccak256(0x00, 0x20))
hash := sload(add(keccak256(0x00, 0x20), index))
}
}

/**
Expand All @@ -69,56 +70,60 @@ library IncrementalMerkleTree {
* @param hash to add
*/
function push(Tree storage t, bytes32 hash) internal {
unchecked {
// index to add to tree
uint256 updateIndex = t.size();
// index to add to tree
uint256 updateIndex = t.size();

// add new layer if tree is at capacity
// update stored tree size

if (updateIndex == (1 << t.height()) >> 1) {
t.nodes.push();
}
assembly {
mstore(0x00, t.slot)
sstore(keccak256(0x00, 0x20), add(updateIndex, 1))
}

// add new columns if rows are full
// add new layer if tree is at capacity

uint256 row;
uint256 col = updateIndex;
uint256 treeHeight = t.height();

while (col == t.nodes[row].length) {
t.nodes[row].push();
row++;
if (col == 0) break;
col >>= 1;
if (updateIndex == (1 << treeHeight) >> 1) {
// increment tree height in storage
assembly {
sstore(t.slot, add(treeHeight, 1))
}
}

// add hash to tree
// add hash to tree

t.set(updateIndex, hash);
}
t.set(updateIndex, hash);
}

function pop(Tree storage t) internal {
uint256 treeSize = t.size();

if (treeSize == 0) {
new bytes32[](0)[1];
}

unchecked {
// index to remove from tree
uint256 updateIndex = t.size() - 1;
uint256 updateIndex = treeSize - 1;

// remove columns if rows are too long
// update stored tree size

uint256 row;
uint256 col = updateIndex;

while (col != t.nodes[row].length) {
t.nodes[row].pop();
row++;
col >>= 1;
if (col == 0) break;
assembly {
mstore(0x00, t.slot)
sstore(keccak256(0x00, 0x20), updateIndex)
}

// if new tree is full, remove excess layer
// if no layer is removed, recalculate hashes

if (updateIndex == (1 << t.height()) >> 2) {
t.nodes.pop();
uint256 treeHeight = t.height();

if (updateIndex == (1 << treeHeight) >> 2) {
// decrement tree height in storage
assembly {
sstore(t.slot, sub(treeHeight, 1))
}
} else {
t.set(updateIndex - 1, t.at(updateIndex - 1));
}
Expand All @@ -132,25 +137,36 @@ library IncrementalMerkleTree {
* @param hash new hash to add
*/
function set(Tree storage t, uint256 index, bytes32 hash) internal {
_set(t.nodes, 0, index, t.size(), hash);
uint256 treeSize = t.size();

if (index >= treeSize) {
new bytes32[](0)[1];
}

_set(t, 0, index, treeSize, hash);
}

/**
* @notice update element in tree and recursively recalculate hashes
* @param nodes internal tree structure storage reference
* @param t Tree struct storage reference
* @param rowIndex index of current row to update
* @param colIndex index of current column to update
* @param rowLength length of row at rowIndex
* @param hash hash to store at current position
*/
function _set(
bytes32[][] storage nodes,
Tree storage t,
uint256 rowIndex,
uint256 colIndex,
uint256 rowLength,
bytes32 hash
) private {
bytes32[] storage row = nodes[rowIndex];
bytes32[] storage row;

assembly {
mstore(0x00, t.slot)
row.slot := add(keccak256(0x00, 0x20), rowIndex)
}

// store hash in array via assembly to avoid array length sload

Expand All @@ -165,7 +181,6 @@ library IncrementalMerkleTree {
if (colIndex & 1 == 1) {
// sibling is on the left
assembly {
mstore(0x00, row.slot)
let sibling := sload(
add(keccak256(0x00, 0x20), sub(colIndex, 1))
)
Expand All @@ -176,7 +191,6 @@ library IncrementalMerkleTree {
} else if (colIndex < rowLength - 1) {
// sibling is on the right (and sibling exists)
assembly {
mstore(0x00, row.slot)
let sibling := sload(
add(keccak256(0x00, 0x20), add(colIndex, 1))
)
Expand All @@ -185,14 +199,8 @@ library IncrementalMerkleTree {
hash := keccak256(0x00, 0x40)
}
}

_set(
nodes,
rowIndex + 1,
colIndex >> 1,
(rowLength + 1) >> 1,
hash
);
}

_set(t, rowIndex + 1, colIndex >> 1, (rowLength + 1) >> 1, hash);
}
}