diff --git a/src/hashes.ts b/src/hashes.ts index fcda8b6..763be40 100644 --- a/src/hashes.ts +++ b/src/hashes.ts @@ -4,9 +4,16 @@ import { BytesLike, HexString, concat, compare } from './bytes'; export type LeafHash = (leaf: T) => HexString; export type NodeHash = (left: BytesLike, right: BytesLike) => HexString; +export type Encoder = { + encode: (types: string[], values: any[]) => string; +}; -export function standardLeafHash(types: string[], value: T): HexString { - return keccak256(keccak256(defaultAbiCoder.encode(types, value))); +export function standardLeafHash( + types: string[], + value: T, + encoder: Encoder = defaultAbiCoder, +): HexString { + return keccak256(keccak256(encoder.encode(types, value))); } export function standardNodeHash(a: BytesLike, b: BytesLike): HexString { diff --git a/src/standard.test.ts b/src/standard.test.ts index 6d7f404..8303b61 100644 --- a/src/standard.test.ts +++ b/src/standard.test.ts @@ -3,6 +3,7 @@ import { HashZero as zero } from '@ethersproject/constants'; import { keccak256 } from '@ethersproject/keccak256'; import { StandardMerkleTree } from './standard'; import { InvalidArgumentError, InvariantError } from './utils/errors'; +import { defaultAbiCoder } from '@ethersproject/abi'; fc.configureGlobal({ numRuns: process.env.CI ? 5000 : 100 }); @@ -149,3 +150,28 @@ test('reject malformed tree dump', t => { new InvariantError('Merkle tree is invalid'), ); }); + +const customEncoderLeaf = fc.tuple( + fc.uint8Array({ minLength: 1, maxLength: 1 }), + fc.uint8Array({ minLength: 1, maxLength: 1 }), +); +const customEncoderLeaves = fc.array(customEncoderLeaf, { minLength: 1, maxLength: 1000 }); + +testProp('custom encoder', [customEncoderLeaves], (t, customEncoderLeaves) => { + const customEncoder = { + encode: (types: string[], values: any[]) => { + return defaultAbiCoder.encode(types.slice().reverse(), values); + }, + }; + + const customEncoderEncoding = ['bytes', 'bytes1']; + + const customEncoderTree = StandardMerkleTree.of(customEncoderLeaves, customEncoderEncoding, {}, customEncoder); + const reversedEncodingTree = StandardMerkleTree.of(customEncoderLeaves, customEncoderEncoding.slice().reverse()); + + t.deepEqual(customEncoderTree.root, reversedEncodingTree.root); + + const defaultEncodingTree = StandardMerkleTree.of(customEncoderLeaves, customEncoderEncoding); + + t.notDeepEqual(customEncoderTree.root, defaultEncodingTree.root); +}); diff --git a/src/standard.ts b/src/standard.ts index c69488d..073052f 100644 --- a/src/standard.ts +++ b/src/standard.ts @@ -1,8 +1,9 @@ +import { defaultAbiCoder } from '@ethersproject/abi'; import { BytesLike, HexString, toHex } from './bytes'; import { MultiProof, processProof, processMultiProof } from './core'; import { MerkleTreeData, MerkleTreeImpl } from './merkletree'; import { MerkleTreeOptions } from './options'; -import { standardLeafHash } from './hashes'; +import { Encoder, standardLeafHash } from './hashes'; import { validateArgument } from './utils/errors'; export interface StandardMerkleTreeData extends MerkleTreeData { @@ -15,44 +16,58 @@ export class StandardMerkleTree extends MerkleTreeImpl { protected readonly tree: HexString[], protected readonly values: StandardMerkleTreeData['values'], protected readonly leafEncoding: string[], + protected readonly encoder: Encoder = defaultAbiCoder, ) { - super(tree, values, leaf => standardLeafHash(leafEncoding, leaf)); + super(tree, values, leaf => standardLeafHash(leafEncoding, leaf, encoder)); } static of( values: T[], leafEncoding: string[], options: MerkleTreeOptions = {}, + encoder: Encoder = defaultAbiCoder, ): StandardMerkleTree { // use default nodeHash (standardNodeHash) - const [tree, indexedValues] = MerkleTreeImpl.prepare(values, options, leaf => standardLeafHash(leafEncoding, leaf)); - return new StandardMerkleTree(tree, indexedValues, leafEncoding); + const [tree, indexedValues] = MerkleTreeImpl.prepare(values, options, leaf => + standardLeafHash(leafEncoding, leaf, encoder), + ); + return new StandardMerkleTree(tree, indexedValues, leafEncoding, encoder); } - static load(data: StandardMerkleTreeData): StandardMerkleTree { + static load( + data: StandardMerkleTreeData, + encoder: Encoder = defaultAbiCoder, + ): StandardMerkleTree { validateArgument(data.format === 'standard-v1', `Unknown format '${data.format}'`); validateArgument(data.leafEncoding !== undefined, 'Expected leaf encoding'); - const tree = new StandardMerkleTree(data.tree, data.values, data.leafEncoding); + const tree = new StandardMerkleTree(data.tree, data.values, data.leafEncoding, encoder); tree.validate(); return tree; } - static verify(root: BytesLike, leafEncoding: string[], leaf: T, proof: BytesLike[]): boolean { + static verify( + root: BytesLike, + leafEncoding: string[], + leaf: T, + proof: BytesLike[], + encoder: Encoder = defaultAbiCoder, + ): boolean { // use default nodeHash (standardNodeHash) for processProof - return toHex(root) === processProof(standardLeafHash(leafEncoding, leaf), proof); + return toHex(root) === processProof(standardLeafHash(leafEncoding, leaf, encoder), proof); } static verifyMultiProof( root: BytesLike, leafEncoding: string[], multiproof: MultiProof, + encoder: Encoder = defaultAbiCoder, ): boolean { // use default nodeHash (standardNodeHash) for processMultiProof return ( toHex(root) === processMultiProof({ - leaves: multiproof.leaves.map(leaf => standardLeafHash(leafEncoding, leaf)), + leaves: multiproof.leaves.map(leaf => standardLeafHash(leafEncoding, leaf, encoder)), proof: multiproof.proof, proofFlags: multiproof.proofFlags, })