From 76854a0e906c5d983cb75197996b7d8a91a91521 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Fri, 29 Aug 2025 23:04:45 +0200 Subject: [PATCH 01/13] xmss: add full spec --- src/lean_spec/subspecs/xmss/__init__.py | 14 +- src/lean_spec/subspecs/xmss/constants.py | 14 + src/lean_spec/subspecs/xmss/hypercube.py | 247 +++++++++++++ src/lean_spec/subspecs/xmss/interface.py | 170 ++++++++- src/lean_spec/subspecs/xmss/merkle_tree.py | 219 ++++++++++++ src/lean_spec/subspecs/xmss/prf.py | 135 ++++++++ src/lean_spec/subspecs/xmss/structures.py | 61 +++- src/lean_spec/subspecs/xmss/tweak_hash.py | 327 ++++++++++++++++++ .../lean_spec/subspecs/xmss/test_hypercube.py | 225 ++++++++++++ .../subspecs/xmss/test_merkle_tree.py | 104 ++++++ tests/lean_spec/subspecs/xmss/test_prf.py | 67 ++++ 11 files changed, 1561 insertions(+), 22 deletions(-) create mode 100644 src/lean_spec/subspecs/xmss/hypercube.py create mode 100644 src/lean_spec/subspecs/xmss/merkle_tree.py create mode 100644 src/lean_spec/subspecs/xmss/prf.py create mode 100644 src/lean_spec/subspecs/xmss/tweak_hash.py create mode 100644 tests/lean_spec/subspecs/xmss/test_hypercube.py create mode 100644 tests/lean_spec/subspecs/xmss/test_merkle_tree.py create mode 100644 tests/lean_spec/subspecs/xmss/test_prf.py diff --git a/src/lean_spec/subspecs/xmss/__init__.py b/src/lean_spec/subspecs/xmss/__init__.py index 9e381730..b992bfe6 100644 --- a/src/lean_spec/subspecs/xmss/__init__.py +++ b/src/lean_spec/subspecs/xmss/__init__.py @@ -7,7 +7,14 @@ from .constants import LIFETIME, MESSAGE_LENGTH from .interface import key_gen, sign, verify -from .structures import HashTreeOpening, PublicKey, SecretKey, Signature +from .merkle_tree import build_tree, get_path, get_root, verify_path +from .structures import ( + HashTree, + HashTreeOpening, + PublicKey, + SecretKey, + Signature, +) __all__ = [ "key_gen", @@ -19,4 +26,9 @@ "HashTreeOpening", "LIFETIME", "MESSAGE_LENGTH", + "build_tree", + "get_path", + "get_root", + 
"verify_path", + "HashTree", ] diff --git a/src/lean_spec/subspecs/xmss/constants.py b/src/lean_spec/subspecs/xmss/constants.py index e2f14bb0..eea67dc5 100644 --- a/src/lean_spec/subspecs/xmss/constants.py +++ b/src/lean_spec/subspecs/xmss/constants.py @@ -16,6 +16,20 @@ from ..koalabear import Fp +PRF_KEY_LENGTH: int = 32 +"""The length of the PRF secret key in bytes.""" + + +MAX_TRIES: int = 100_000 +""" +How often one should try at most to resample a random value. + +This is currently based on experiments with the Rust implementation. + +Should probably be modified in production. +""" + + # ================================================================= # Core Scheme Configuration # ================================================================= diff --git a/src/lean_spec/subspecs/xmss/hypercube.py b/src/lean_spec/subspecs/xmss/hypercube.py new file mode 100644 index 00000000..a553480d --- /dev/null +++ b/src/lean_spec/subspecs/xmss/hypercube.py @@ -0,0 +1,247 @@ +""" +Implements the mathematical operations for hypercube layers. + +This module provides the necessary functions to work with vertices in a +v-dimensional hypercube with coordinates in the range [0, w-1]. A key concept +is the partitioning of the hypercube's vertices into "layers". A vertex belongs +to layer `d`, where `d` is its distance from the sink vertex `(w-1, ..., w-1)`. + +The core functionalities are: +1. **Precomputation and Caching**: Efficiently calculates and caches the sizes + of each layer for different hypercube configurations (`w` and `v`). +2. **Mapping**: Provides bijective mappings between an integer index within a + layer and the unique vertex (a list of coordinates) it represents. + +This logic is a direct translation of the algorithms described in the paper +"At the top of the hypercube" (eprint 2025/889) +and the reference Rust implementation: https://github.com/b-wagn/hash-sig. 
+""" + +from __future__ import annotations + +import bisect +from functools import lru_cache +from typing import List, Tuple + +from pydantic import BaseModel, ConfigDict + +MAX_DIMENSION = 100 +"""The maximum dimension `v` for which layer sizes will be precomputed.""" + + +class LayerInfo(BaseModel): + """ + Stores the precomputed sizes and cumulative sums for a + hypercube's layers. + """ + + model_config = ConfigDict(frozen=True) + sizes: List[int] + """The number of vertices in each layer `d`.""" + prefix_sums: List[int] + """ + The cumulative number of vertices up to and including layer `d`. + + `prefix_sums[d] = sizes[0] + ... + sizes[d]`. + """ + + def sizes_sum_in_range(self, start: int, end: int) -> int: + """Calculates the sum of sizes in an inclusive range [start, end].""" + if start > end: + return 0 + if start == 0: + return self.prefix_sums[end] + else: + return self.prefix_sums[end] - self.prefix_sums[start - 1] + + +@lru_cache(maxsize=None) +def prepare_layer_info(w: int) -> List[LayerInfo]: + """ + Precomputes and caches the number of vertices in each layer of a hypercube. + + This function is a crucial precomputation step for the mapping algorithms + used in the "Top Level Target Sum" encoding. + It calculates the size of every layer for hypercubes with a given + base `w` (where coordinates are in `[0, w-1]`) for all dimensions + `v` up to `MAX_DIMENSION`. + + The calculation is inductive: the layer sizes for a `v`-dimensional + hypercube are efficiently derived from the already-computed sizes + of a `(v-1)`-dimensional hypercube, based on the recurrence relation from + Lemma 8 of the paper "At the top of the hypercube" (eprint 2025/889). + + Args: + w: The base of the hypercube. + + Returns: + A list where the element at index `v` is a `LayerInfo` object + containing the layer sizes for a `v`-dimensional hypercube. + """ + # Initialize a list to store the results for each dimension `v`. 
+ # + # Index 0 is unused to allow for direct indexing, e.g., `all_info[v]`. + all_info = [LayerInfo(sizes=[], prefix_sums=[])] * (MAX_DIMENSION + 1) + + # -- BASE CASE -- + # For a 1-dimensional hypercube (v=1), which is just a line of `w` points + # with coordinates [0], [1], ..., [w-1]. + # The distance `d` from the sink `[w-1]` is simply `(w-1) - coordinate`. + + # Each of the `w` possible layers contains exactly one vertex. + dim1_sizes = [1] * w + # The prefix sums (cumulative sizes) are therefore just [1, 2, 3, ..., w]. + dim1_prefix_sums = list(range(1, w + 1)) + # Store the result for v=1, which will seed the inductive step. + all_info[1] = LayerInfo(sizes=dim1_sizes, prefix_sums=dim1_prefix_sums) + + # -- INDUCTIVE STEP -- + # Now, build the layer info for all higher dimensions up to the maximum. + for v in range(2, MAX_DIMENSION + 1): + # The maximum possible distance `d` in a v-dimensional hypercube. + max_d = (w - 1) * v + # Retrieve the already-computed data for the previous dimension (v-1). + prev_layer_info = all_info[v - 1] + + # This list will store the computed size of each layer `d` + # for dimension `v`. + current_sizes: List[int] = [] + for d in range(max_d + 1): + # Implements the recurrence l_d(v) = Σ l_{d-j}(v-1) from the paper. + # `j` is one coordinate's distance contribution. + + # Calculate the valid range [j_min, j_max] for `j`. + j_min = max(0, d - (w - 1) * (v - 1)) + j_max = min(w - 1, d) + + # Translate the sum over `j` to an index range `k`, + # where k = d - j. + # This allows for an efficient lookup using prefix sums. + k_min = d - j_max + k_max = d - j_min + + # Efficiently calculate the sum using the precomputed prefix sums + # from the previous dimension's `LayerInfo`. + layer_size = prev_layer_info.sizes_sum_in_range(k_min, k_max) + current_sizes.append(layer_size) + + # After computing all layer sizes for dimension `v`, we compute their + # prefix sums. + # This is needed for the *next* iteration (for dimension v+1). 
+ current_prefix_sums: List[int] = [] + current_sum = 0 + for size in current_sizes: + current_sum += size + current_prefix_sums.append(current_sum) + + # Store the complete layer info for the current dimension `v`. + all_info[v] = LayerInfo( + sizes=current_sizes, prefix_sums=current_prefix_sums + ) + + # Return the complete table of layer information for the given base `w`. + return all_info + + +def get_layer_size(w: int, v: int, d: int) -> int: + """Returns the size of a specific layer `d` in a `(w, v)` hypercube.""" + return prepare_layer_info(w)[v].sizes[d] + + +def get_hypercube_part_size(w: int, v: int, d: int) -> int: + """Returns the total size of layers 0 to `d` (inclusive).""" + return prepare_layer_info(w)[v].prefix_sums[d] + + +def find_layer(w: int, v: int, x: int) -> Tuple[int, int]: + """ + Given a global index `x`, finds the layer `d` it belongs to and its + local index (`remainder`) within that layer. + + Args: + w: The hypercube base. + v: The hypercube dimension. + x: The global index of a vertex (from 0 to w**v - 1). + + Returns: + A tuple `(d, remainder)`. + """ + prefix_sums = prepare_layer_info(w)[v].prefix_sums + # Use binary search to efficiently find the correct layer. + d = bisect.bisect_left(prefix_sums, x + 1) + + if d == 0: + remainder = x + else: + remainder = x - prefix_sums[d - 1] + + return d, remainder + + +def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]: + """ + Maps an integer index `x` to a unique vertex in a specific hypercube layer. + + This function provides a bijective mapping from an integer `x` + (derived from a hash) to a unique list of `v` coordinates, + `[a_0, ..., a_{v-1}]`. + + The algorithm works iteratively, determining one coordinate at a time. + + Args: + w: The hypercube base (coordinates are in `[0, w-1]`). + v: The hypercube dimension (the number of coordinates). + d: The target layer, defined by its distance from the sink vertex. 
+ x: The integer index within the layer `d`, must be `0 <= x < size(d)`. + + Returns: + A list of `v` integers representing the coordinates of the vertex. + """ + layer_info_cache = prepare_layer_info(w) + + # Validate that the input index `x` is valid for the target layer. + layer_size = layer_info_cache[v].sizes[d] + if x >= layer_size: + raise ValueError("Index x is out of bounds for the given layer.") + + vertex: List[int] = [] + # Track remaining distance and index. + d_curr, x_curr = d, x + + # Determine each of the first v-1 coordinates iteratively. + for i in range(1, v): + dim_remaining = v - i + prev_dim_layer_info = layer_info_cache[dim_remaining] + + # This loop finds which block of sub-hypercubes the index `x_curr` + # falls into. + # + # It skips over full blocks by subtracting their size + # from `x_curr` until the correct one is found. + ji = -1 # Sentinel value + range_start = max(0, d_curr - (w - 1) * dim_remaining) + for j in range(range_start, min(w, d_curr + 1)): + count = prev_dim_layer_info.sizes[d_curr - j] + if x_curr >= count: + x_curr -= count + else: + ji = j # Found the correct block. + break + + if ji == -1: + raise RuntimeError( + "Internal logic error: failed to find coordinate" + ) + + # Convert the block's distance contribution `ji` to a coordinate `ai`. + ai = w - 1 - ji + vertex.append(ai) + + # Update the remaining distance for the next, smaller subproblem. + d_curr -= ji + + # The final coordinate is uniquely determined by the remaining values. 
+ last_coord = w - 1 - x_curr - d_curr + vertex.append(last_coord) + + return vertex diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index e8e7e492..89aa19a8 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -8,9 +8,21 @@ from __future__ import annotations -from typing import Tuple - -from .structures import PublicKey, SecretKey, Signature +from typing import List, Tuple + +from .constants import BASE, DIMENSION, LIFETIME, LOG_LIFETIME +from .merkle_tree import build_tree, get_root +from .prf import apply as apply_prf +from .prf import key_gen as prf_key_gen +from .structures import HashDigest, PublicKey, SecretKey, Signature +from .tweak_hash import ( + TreeTweak, + hash_chain, + rand_parameter, +) +from .tweak_hash import ( + apply as apply_tweakable_hash, +) def key_gen( @@ -19,9 +31,17 @@ def key_gen( """ Generates a new cryptographic key pair. This is a **randomized** algorithm. - This function is a placeholder. In a real implementation, it would involve - generating a master secret, deriving all one-time keys, and constructing - the full Merkle tree. + This function executes the full key generation process: + 1. Generates a master secret (PRF key) and a public hash parameter. + 2. For each epoch in the active range, it uses the PRF to derive the + secret starting points for all `DIMENSION` hash chains. + 3. It computes the public endpoint of each chain by + hashing `BASE - 1` times. + 4. The list of all chain endpoints for an epoch forms the + one-time public key. + This one-time public key is hashed to create a single Merkle leaf. + 5. A Merkle tree is built over all generated leaves, and its root becomes + part of the final public key. Args: activation_epoch: The starting epoch for which this key is active. 
@@ -33,10 +53,52 @@ def key_gen( - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - The canonical Rust implementation: https://github.com/b-wagn/hash-sig """ - raise NotImplementedError( - "key_gen is not part of this specification. " - "See the Rust reference implementation." + # Validate the activation range against the scheme's total lifetime. + if activation_epoch + num_active_epochs > LIFETIME: + raise ValueError("Activation range exceeds the key's lifetime.") + + # Generate the random public parameter `P` and the master PRF key. + parameter = rand_parameter() + prf_key = prf_key_gen() + + # For each epoch, generate the corresponding Merkle leaf hash. + leaf_hashes: List[HashDigest] = [] + for epoch in range(activation_epoch, activation_epoch + num_active_epochs): + # For each epoch, we compute `DIMENSION` chain endpoints. + chain_ends: List[HashDigest] = [] + for chain_index in range(DIMENSION): + # Derive the secret start of the chain from the master key. + start_digest = apply_prf(prf_key, epoch, chain_index) + # Compute the public end of the chain by hashing BASE - 1 times. + end_digest = hash_chain( + parameter=parameter, + epoch=epoch, + chain_index=chain_index, + start_step=0, + num_steps=BASE - 1, + start_digest=start_digest, + ) + chain_ends.append(end_digest) + + # The Merkle leaf is the hash of all chain endpoints for this epoch. + leaf_tweak = TreeTweak(level=0, index=epoch) + leaf_hash = apply_tweakable_hash(parameter, leaf_tweak, chain_ends) + leaf_hashes.append(leaf_hash) + + # Build the Merkle tree over the generated leaves. + tree = build_tree(LOG_LIFETIME, activation_epoch, parameter, leaf_hashes) + root = get_root(tree) + + # Assemble and return the public and secret keys. 
+ pk = PublicKey(root=root, parameter=parameter) + sk = SecretKey( + prf_key=prf_key, + tree=tree, + parameter=parameter, + activation_epoch=activation_epoch, + num_active_epochs=num_active_epochs, ) + return pk, sk def sign(sk: SecretKey, epoch: int, message: bytes) -> Signature: @@ -44,17 +106,65 @@ def sign(sk: SecretKey, epoch: int, message: bytes) -> Signature: Produces a digital signature for a given message at a specific epoch. This is a **randomized** algorithm. - This function is a placeholder. The signing process involves encoding the - message, generating a one-time signature, and providing a Merkle path. - **CRITICAL**: This function must never be called twice with the same secret key and epoch for different messages, as this would compromise security. + The signing process involves: + 1. Repeatedly attempting to encode the message with fresh randomness + (`rho`) until a valid codeword is found. + 2. Computing the one-time signature, which consists of intermediate + values from the secret hash chains, determined by the digits of + the codeword. + 3. Retrieving the Merkle authentication path for the given epoch. + For the formal specification of this process, please refer to: - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - The canonical Rust implementation: https://github.com/b-wagn/hash-sig """ + # # 1. Check that the key is active for the requested signing epoch. + # active_range = range( + # sk.activation_epoch, sk.activation_epoch + sk.num_active_epochs + # ) + # if epoch not in active_range: + # raise ValueError("Key is not active for the specified epoch.") + + # # 2. Find a valid message encoding by trying different randomness `rho`. 
+ # codeword = None + # rho = None + # for _ in range(MAX_TRIES): + # current_rho = rand_rho() + # current_codeword = encode(sk.parameter, message, current_rho, epoch) + # if current_codeword is not None: + # codeword = current_codeword + # rho = current_rho + # break + + # # If no valid encoding is found after many tries, signing fails. + # if codeword is None or rho is None: + # raise RuntimeError("Failed to find a valid message encoding.") + + # # 3. Compute the one-time signature hashes based on the codeword. + # ots_hashes: List[HashDigest] = [] + # for chain_index, steps in enumerate(codeword): + # # Derive the secret start of the chain from the master key. + # start_digest = apply_prf(sk.prf_key, epoch, chain_index) + # # Walk the chain for the number of steps given by the codeword digit. + # ots_digest = hash_chain( + # parameter=sk.parameter, + # epoch=epoch, + # chain_index=chain_index, + # start_step=0, + # num_steps=steps, + # start_digest=start_digest, + # ) + # ots_hashes.append(ots_digest) + + # # 4. Get the Merkle authentication path for the current epoch. + # path = get_path(sk.tree, epoch) + + # # 5. Assemble and return the final signature. + # return Signature(path=path, rho=rho, hashes=ots_hashes) raise NotImplementedError( "sign is not part of this specification. " "See the Rust reference implementation." @@ -104,6 +214,40 @@ def verify(pk: PublicKey, epoch: int, message: bytes, sig: Signature) -> bool: - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - The canonical Rust implementation: https://github.com/b-wagn/hash-sig """ + # # 1. Re-encode the message using the randomness `rho` from the signature. + # # If the encoding is invalid, the signature is invalid. + # codeword = encode(pk.parameter, message, sig.rho, epoch) + # if codeword is None: + # return False + + # # 2. Reconstruct the one-time public key (the list of chain endpoints). 
+    # chain_ends: List[HashDigest] = []
+    # for chain_index, xi in enumerate(codeword):
+    #     # The signature provides the hash value after `xi` steps.
+    #     start_digest = sig.hashes[chain_index]
+    #     # We must perform the remaining `BASE - 1 - xi` steps
+    #     # to get to the end.
+    #     num_steps_remaining = BASE - 1 - xi
+    #     end_digest = hash_chain(
+    #         parameter=pk.parameter,
+    #         epoch=epoch,
+    #         chain_index=chain_index,
+    #         start_step=xi,
+    #         num_steps=num_steps_remaining,
+    #         start_digest=start_digest,
+    #     )
+    #     chain_ends.append(end_digest)

+    # # 3. Verify the Merkle path. This function internally hashes `chain_ends`
+    # # to get the leaf node and then climbs the tree to recompute the root.
+    # return verify_path(
+    #     parameter=pk.parameter,
+    #     root=pk.root,
+    #     position=epoch,
+    #     leaf_parts=chain_ends,
+    #     opening=sig.path,
+    # )
     raise NotImplementedError(
-        "verify will be implemented in a future update to the specification."
+        "verify is not part of this specification. "
+        "See the Rust reference implementation."
     )
diff --git a/src/lean_spec/subspecs/xmss/merkle_tree.py b/src/lean_spec/subspecs/xmss/merkle_tree.py
new file mode 100644
index 00000000..d6eaff4a
--- /dev/null
+++ b/src/lean_spec/subspecs/xmss/merkle_tree.py
@@ -0,0 +1,219 @@
+"""
+Implements the sparse Merkle tree used in the Generalized XMSS scheme.
+
+This module provides the data structures and algorithms for building a Merkle
+tree over a contiguous subset of leaves, computing authentication paths, and
+verifying those paths.
+
+The key features are:
+1. **Sparsity**: The tree can represent a massive address space (e.g., 2^32
+   leaves) while only storing the nodes relevant to a smaller, active range of
+   leaves (e.g., 2^20 leaves).
+2. **Random Padding**: To simplify the logic for pairing sibling nodes, the
+   active layers are padded with random hash values to ensure they always
+   start at an even index and end at an odd index.
+""" + +from __future__ import annotations + +from typing import List + +from .structures import ( + HashDigest, + HashTree, + HashTreeLayer, + HashTreeOpening, + Parameter, +) +from .tweak_hash import TreeTweak, rand_domain +from .tweak_hash import apply as apply_tweakable_hash + + +def _get_padded_layer( + nodes: List[HashDigest], start_index: int +) -> HashTreeLayer: + """ + Pads a layer to ensure its nodes can always be paired up. + + This helper function adds random padding to the start and/or end of a list + of nodes to enforce an invariant: every layer must start at an even index + and end at an odd index. This guarantees that every node has a sibling, + simplifying the construction of the next layer up. + + Args: + nodes: The list of active nodes for the current layer. + start_index: The starting index of the first node in `nodes`. + + Returns: + A new `HashTreeLayer` with padding applied. + """ + nodes_with_padding: List[HashDigest] = [] + end_index = start_index + len(nodes) - 1 + + # Prepend random padding if the layer starts at an odd index. + if start_index % 2 == 1: + nodes_with_padding.append(rand_domain()) + + # The actual start index of the padded layer is always the even + # number at or immediately before the original start_index. + actual_start_index = start_index - (start_index % 2) + + # Add the actual node content. + nodes_with_padding.extend(nodes) + + # Append random padding if the layer ends at an even index. + if end_index % 2 == 0: + nodes_with_padding.append(rand_domain()) + + return HashTreeLayer( + start_index=actual_start_index, nodes=nodes_with_padding + ) + + +def build_tree( + depth: int, + start_index: int, + parameter: Parameter, + leaf_hashes: List[HashDigest], +) -> HashTree: + """ + Builds a new sparse Merkle tree from a list of leaf hashes. + + The construction proceeds bottom-up, from the leaf layer to the root. + At each level, pairs of sibling nodes are hashed to create their parents + for the next level up. 
+ + Args: + depth: The depth of the tree (e.g., 32 for a 2^32 leaf space). + start_index: The index of the first leaf in `leaf_hashes`. + parameter: The public parameter `P` for the hash function. + leaf_hashes: The list of pre-hashed leaf nodes to build the tree on. + + Returns: + The fully constructed `HashTree` object. + """ + # Start with the leaf hashes and apply the initial padding. + layers: List[HashTreeLayer] = [] + current_layer = _get_padded_layer(leaf_hashes, start_index) + layers.append(current_layer) + + # Iterate from the leaf layer (level 0) up to the root. + for level in range(depth): + parents: List[HashDigest] = [] + # Group the current layer's nodes into pairs of siblings. + for i, children in enumerate( + zip( + current_layer.nodes[0::2], + current_layer.nodes[1::2], + strict=False, + ) + ): + # Calculate the position of the parent node in the next level up. + parent_index = (current_layer.start_index // 2) + i + # Create the tweak for hashing these two children. + tweak = TreeTweak(level=level + 1, index=parent_index) + # Hash the left and right children to get their parent. + parent_node = apply_tweakable_hash( + parameter, tweak, list(children) + ) + parents.append(parent_node) + + # Pad the new list of parents to prepare for the next iteration. + new_start_index = current_layer.start_index // 2 + current_layer = _get_padded_layer(parents, new_start_index) + layers.append(current_layer) + + return HashTree(depth=depth, layers=layers) + + +def get_root(tree: HashTree) -> HashDigest: + """Extracts the root digest from a constructed Merkle tree.""" + # The root is the single node in the final layer. + return tree.layers[-1].nodes[0] + + +def get_path(tree: HashTree, position: int) -> HashTreeOpening: + """ + Computes the Merkle authentication path for a leaf at a given position. + + The path consists of the list of sibling nodes required to reconstruct the + root, starting from the leaf's sibling and going up the tree. 
+ + Args: + tree: The `HashTree` from which to extract the path. + position: The absolute index of the leaf for which to generate path. + + Returns: + A `HashTreeOpening` object containing the co-path. + """ + co_path: List[HashDigest] = [] + current_position = position + + # Iterate from the bottom layer (level 0) up to the layer below the root. + for level in range(tree.depth): + # Determine the sibling's position using an XOR operation. + sibling_position = current_position ^ 1 + # Find the sibling's index within the sparse `nodes` vector. + layer = tree.layers[level] + sibling_index_in_vec = sibling_position - layer.start_index + # Add the sibling's hash to the co-path. + co_path.append(layer.nodes[sibling_index_in_vec]) + # Move up to the parent's position for the next iteration. + current_position //= 2 + + return HashTreeOpening(siblings=co_path) + + +def verify_path( + parameter: Parameter, + root: HashDigest, + position: int, + leaf_parts: List[HashDigest], + opening: HashTreeOpening, +) -> bool: + """ + Verifies a Merkle authentication path. + + This function reconstructs a candidate root by starting with the leaf node + and repeatedly hashing it with the sibling nodes provided in the opening + path. + + The verification succeeds if the candidate root matches the true root. + + Args: + parameter: The public parameter `P` for the hash function. + root: The known, trusted Merkle root. + position: The absolute index of the leaf being verified. + leaf_parts: The list of digests that constitute the original leaf. + opening: The `HashTreeOpening` object containing the sibling path. + + Returns: + `True` if the path is valid, `False` otherwise. + """ + # The first step is to hash the constituent parts of the leaf to get + # the actual node at layer 0 of the tree. 
+ leaf_tweak = TreeTweak(level=0, index=position) + current_node = apply_tweakable_hash(parameter, leaf_tweak, leaf_parts) + + # Iterate up the tree, hashing the current node with its sibling from + # the path at each level. + current_position = position + for level, sibling_node in enumerate(opening.siblings): + # Determine if the current node is a left or right child. + if current_position % 2 == 0: + # Current node is a left child; sibling is on the right. + children = [current_node, sibling_node] + else: + # Current node is a right child; sibling is on the left. + children = [sibling_node, current_node] + + # Move up to the parent's position for the next iteration. + current_position //= 2 + # Create the tweak for the parent's level and position. + parent_tweak = TreeTweak(level=level + 1, index=current_position) + # Hash the children to compute the parent node. + current_node = apply_tweakable_hash(parameter, parent_tweak, children) + + # After iterating through the entire path, the final computed node + # should be the root of the tree. + return current_node == root diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py new file mode 100644 index 00000000..65fd5178 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -0,0 +1,135 @@ +""" +Defines the pseudorandom function (PRF) used in the signature scheme. + +PRF based on the SHAKE128 extendable-output function (XOF). + +The PRF is used to derive the secret starting points of the hash chains +for each epoch from a single master secret key. 
+""" + +from __future__ import annotations + +import hashlib +import os + +from lean_spec.subspecs.koalabear import Fp +from lean_spec.subspecs.xmss.constants import HASH_LEN_FE, PRF_KEY_LENGTH +from lean_spec.subspecs.xmss.structures import HashDigest, PRFKey +from lean_spec.types.uint64 import Uint64 + +PRF_DOMAIN_SEP: bytes = bytes( + [ + 0xAE, + 0xAE, + 0x22, + 0xFF, + 0x00, + 0x01, + 0xFA, + 0xFF, + 0x21, + 0xAF, + 0x12, + 0x00, + 0x01, + 0x11, + 0xFF, + 0x00, + ] +) +""" +A 16-byte domain separator to ensure PRF outputs are unique to this context. + +This prevents any potential conflicts if the same underlying hash function +(SHAKE128) were used for other purposes in the system. +""" + +PRF_BYTES_PER_FE: int = 8 +""" +The number of bytes of SHAKE128 output used to generate one field element. + +We use 8 bytes (64 bits) of pseudorandom output, which is then reduced +modulo the 31-bit field prime `P`. This provides a significant statistical +safety margin to ensure the resulting field element is close to uniformly +random. +""" + + +PRFOutput = HashDigest +""" +A type alias for the output of the PRF. +It is a list of field elements. +""" + + +def key_gen() -> PRFKey: + """ + Generates a cryptographically secure random key for the PRF. + + This function sources randomness from the operating system's entropy pool. + + Returns: + A new, randomly generated PRF key of `PRF_KEY_LENGTH` bytes. + """ + # Use os.urandom for cryptographically secure random bytes. + return os.urandom(PRF_KEY_LENGTH) + + +def apply(key: PRFKey, epoch: int, chain_index: Uint64) -> PRFOutput: + """ + Applies the PRF to derive the secret values for a specific epoch and chain. + + The function computes `SHAKE128(DOMAIN_SEP || key || epoch || chain_index)` + and interprets the output as a list of field elements. + + Args: + key: The secret PRF key. + epoch: The epoch number (a 32-bit unsigned integer). + chain_index: The index of the hash chain (a 64-bit unsigned integer). 
+ + Returns: + A list of `DIMENSION` field elements, which are the secret starting + points for the hash chains of the specified epoch. + """ + # Create a new SHAKE128 hash instance. + hasher = hashlib.shake_128() + + # Absorb the domain separator to contextualize the hash. + hasher.update(PRF_DOMAIN_SEP) + + # Absorb the secret key. + hasher.update(key) + + # Absorb the epoch, represented as a 4-byte big-endian integer. + # + # This ensures that each epoch produces a unique set of secrets. + hasher.update(epoch.to_bytes(4, "big")) + + # Absorb the chain index, as an 8-byte big-endian integer. + # + # This is used to derive a unique start value for each hash chain. + hasher.update(chain_index.to_bytes(8, "big")) + + # Determine the total number of bytes to extract from the SHAKE output. + # + # For key generation, we need one field element per chain. + num_bytes_to_read = PRF_BYTES_PER_FE * HASH_LEN_FE + prf_output_bytes = hasher.digest(num_bytes_to_read) + + # Convert the byte output into a list of field elements. + output_elements: PRFOutput = [] + for i in range(HASH_LEN_FE): + # Extract an 8-byte chunk for the current field element. + start = i * PRF_BYTES_PER_FE + end = start + PRF_BYTES_PER_FE + chunk = prf_output_bytes[start:end] + + # Convert the chunk to a large integer. + integer_value = int.from_bytes(chunk, "big") + + # Reduce the integer modulo the field prime `P`. + # + # The Fp constructor handles the modulo operation automatically. 
+ output_elements.append(Fp(value=integer_value)) + + return output_elements diff --git a/src/lean_spec/subspecs/xmss/structures.py b/src/lean_spec/subspecs/xmss/structures.py index 306e9948..78ccc667 100644 --- a/src/lean_spec/subspecs/xmss/structures.py +++ b/src/lean_spec/subspecs/xmss/structures.py @@ -5,7 +5,17 @@ from pydantic import BaseModel, ConfigDict, Field from ..koalabear import Fp -from .constants import HASH_LEN_FE, PARAMETER_LEN, RAND_LEN_FE +from .constants import HASH_LEN_FE, PARAMETER_LEN, PRF_KEY_LENGTH, RAND_LEN_FE + +PRFKey = Annotated[ + bytes, Field(min_length=PRF_KEY_LENGTH, max_length=PRF_KEY_LENGTH) +] +""" +A type alias for the PRF secret key. + +It is a byte string of `PRF_KEY_LENGTH` bytes. +""" + HashDigest = Annotated[ List[Fp], Field(min_length=HASH_LEN_FE, max_length=HASH_LEN_FE) @@ -43,6 +53,41 @@ class HashTreeOpening(BaseModel): ) +class HashTreeLayer(BaseModel): + """ + Represents a single layer within the sparse Merkle tree. + + Attributes: + start_index: The index of the first node in this layer within the full + conceptual tree. + nodes: A list of the actual hash digests stored for this layer. + """ + + model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) + start_index: int + """The starting index of the first node in this layer.""" + nodes: List[HashDigest] + """A list of the actual hash digests stored for this layer.""" + + +class HashTree(BaseModel): + """ + The complete sparse Merkle tree structure. + + Attributes: + depth: The total depth of the tree (e.g., 32 for a 2^32 leaf space). + layers: A list of `HashTreeLayer` objects, from the leaf hashes + (layer 0) up to the layer just below the root. 
+    """
+
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    depth: int
+    """The total depth of the tree (e.g., 32 for a 2^32 leaf space)."""
+    layers: List[HashTreeLayer]
+    """A list of `HashTreeLayer` objects, from the leaf hashes
+    (layer 0) up to the layer just below the root."""
+
+
 class PublicKey(BaseModel):
     """The public key for the Generalized XMSS scheme."""
 
@@ -65,11 +110,11 @@ class Signature(BaseModel):
 
 
 class SecretKey(BaseModel):
-    """
-    Placeholder for the secret key.
+    """The secret key for the Generalized XMSS scheme."""
 
-    Note: The full secret key structure is not specified here as it is not
-    needed for verification.
-    """
-
-    pass
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    prf_key: PRFKey
+    tree: HashTree
+    parameter: Parameter
+    activation_epoch: int
+    num_active_epochs: int
diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py
new file mode 100644
index 00000000..67f18e79
--- /dev/null
+++ b/src/lean_spec/subspecs/xmss/tweak_hash.py
@@ -0,0 +1,327 @@
+"""
+Defines the Tweakable Hash function using Poseidon2.
+
+This module implements the core hashing logic for the XMSS scheme, including:
+1. **Tweak Encoding**: Domain-separating different hash usages (chains/trees).
+2. **Poseidon2 Compression**: Hashing fixed-size inputs.
+3. **Poseidon2 Sponge**: Hashing variable-length inputs (e.g., leaf nodes).
+4. **A unified `apply` function** that dispatches to the correct mode.
+5. **A `chain` utility** to perform repeated hashing for WOTS chains.
+""" + +from __future__ import annotations + +import secrets +from itertools import chain +from typing import List, Union + +from pydantic import Field + +from lean_spec.types.base import StrictBaseModel + +from ..koalabear import Fp, P +from ..poseidon2.permutation import PARAMS_16, PARAMS_24, permute +from .constants import ( + CAPACITY, + DIMENSION, + HASH_LEN_FE, + PARAMETER_LEN, + TWEAK_LEN_FE, + TWEAK_PREFIX_CHAIN, + TWEAK_PREFIX_TREE, +) +from .structures import HashDigest, Parameter + + +class TreeTweak(StrictBaseModel): + """A tweak used for hashing nodes within the Merkle tree.""" + + level: int = Field(ge=0, description="The level in the Merkle tree.") + index: int = Field(ge=0, description="The node's index within that level.") + + +class ChainTweak(StrictBaseModel): + """A tweak used for hashing elements within a WOTS+ hash chain.""" + + epoch: int = Field(ge=0, description="The signature epoch.") + chain_index: int = Field(ge=0, description="The index of the hash chain.") + step: int = Field(ge=0, description="The step number within the chain.") + + +Tweak = Union[TreeTweak, ChainTweak] +"""A type alias representing any valid tweak structure.""" + + +def encode_tweak(tweak: Tweak, length: int) -> List[Fp]: + """ + Encodes a tweak structure into a list of field elements for hashing. + + This function packs the tweak's integer components into a single large + integer, then performs a base-`P` decomposition to get field elements. + This ensures a unique and deterministic mapping from any tweak to a format + consumable by the hash function. + + The packing scheme is designed to be injective: + - `TreeTweak`: `(level << 40) | (index << 8) | TWEAK_PREFIX_TREE` + - `ChainTweak`: `(epoch << 24) | (chain_index << 16) + | (step << 8) | TWEAK_PREFIX_CHAIN` + + Args: + tweak: The `TreeTweak` or `ChainTweak` object. + length: The desired number of field elements in the output list. + + Returns: + A list of `length` field elements representing the encoded tweak. 
+ """ + # Pack the tweak's integer fields into a single large integer. + # + # A unique prefix is included for domain separation. + if isinstance(tweak, TreeTweak): + acc = ( + (tweak.level << 40) | (tweak.index << 8) | TWEAK_PREFIX_TREE.value + ) + else: + acc = ( + (tweak.epoch << 24) + | (tweak.chain_index << 16) + | (tweak.step << 8) + | TWEAK_PREFIX_CHAIN.value + ) + + # Decompose the large integer `acc` into a list of field elements. + # This is a standard base-P decomposition. + # + # The number of elements is determined by the `length` parameter. + elements: List[Fp] = [] + for _ in range(length): + elements.append(Fp(value=acc)) + acc //= P + return elements + + +def poseidon_compress( + input_vec: List[Fp], width: int, output_len: int +) -> HashDigest: + """ + A low-level wrapper for Poseidon2 in compression mode. + + It computes: `Truncate(Permute(padded_input) + padded_input)`. + + Args: + input_vec: The input data to be hashed. + width: The state width of the Poseidon2 permutation (16 or 24). + output_len: The number of field elements in the output digest. + + Returns: + A hash digest of `output_len` field elements. + """ + # Select the correct permutation parameters based on the state width. + params = PARAMS_16 if width == 16 else PARAMS_24 + + # Create a fixed-width buffer and copy the input, padding with zeros. + padded_input = [Fp(value=0)] * width + padded_input[: len(input_vec)] = input_vec + + # Apply the Poseidon2 permutation. + permuted_state = permute(padded_input, params) + + # Apply the feed-forward step, adding the input back element-wise. + final_state = [ + p + i for p, i in zip(permuted_state, padded_input, strict=True) + ] + + # Truncate the state to the desired output length and return. + return final_state[:output_len] + + +def _poseidon_safe_domain_separator( + lengths: List[int], capacity_len: int +) -> List[Fp]: + """ + Computes a domain separator for the sponge construction. 
+ + This function hashes a list of length parameters to create a unique + "capacity value" that configures the sponge for a specific hashing task. + + Args: + lengths: A list of integer parameters defining the hash context. + capacity_len: The desired length of the output capacity value. + + Returns: + A list of `capacity_len` field elements. + """ + # Pack the length parameters into a single large integer. + acc = 0 + for length in lengths: + acc = (acc << 32) | length + + # Decompose the integer into a list of 24 field elements for hashing. + # + # 24 is the fixed input width for this specific domain separation hash. + input_vec: List[Fp] = [] + for _ in range(24): + input_vec.append(Fp(value=acc)) + acc //= P + + # Compress the decomposed vector to produce the capacity value. + return poseidon_compress(input_vec, 24, capacity_len) + + +def _poseidon_sponge( + input_vec: List[Fp], capacity_value: List[Fp], output_len: int +) -> HashDigest: + """ + A low-level wrapper for Poseidon2 in sponge mode. + + Args: + input_vec: The input data of arbitrary length. + capacity_value: A domain-separating value. + output_len: The number of field elements in the output digest. + + Returns: + A hash digest of `output_len` field elements. + """ + # Use the width-24 permutation for the sponge. + params = PARAMS_24 + width = params.width + rate = width - len(capacity_value) + + # Pad the input vector to be an exact multiple of the rate. + num_extra = (rate - (len(input_vec) % rate)) % rate + padded_input = input_vec + [Fp(value=0)] * num_extra + + # Initialize the state with the capacity value. + state = [Fp(value=0)] * width + state[rate:] = capacity_value + + # Absorb the input in rate-sized chunks. + for i in range(0, len(padded_input), rate): + chunk = padded_input[i : i + rate] + # Add the chunk to the rate part of the state. + for j in range(rate): + state[j] += chunk[j] + # Apply the permutation. 
+ state = permute(state, params) + + # Squeeze the output until enough elements have been generated. + output: HashDigest = [] + while len(output) < output_len: + output.extend(state[:rate]) + state = permute(state, params) + + # Truncate to the final output length and return. + return output[:output_len] + + +def apply( + parameter: Parameter, tweak: Tweak, message_parts: List[HashDigest] +) -> HashDigest: + """ + Applies the tweakable Poseidon2 hash function to a message. + + This function serves as the main entry point for all hashing operations. + It automatically selects the correct Poseidon2 mode (compression or sponge) + based on the number of message parts provided. + + Args: + parameter: The public parameter `P` for the hash function. + tweak: A `TreeTweak` or `ChainTweak` for domain separation. + message_parts: A list of one or more hash digests to be hashed. + + Returns: + A new hash digest of `HASH_LEN_FE` field elements. + """ + # Encode the tweak structure into field elements. + encoded_tweak = encode_tweak(tweak, TWEAK_LEN_FE) + + if len(message_parts) == 1: + # Case 1: Hashing a single digest (used in hash chains). + # + # We use the efficient width-16 compression mode. + input_vec = parameter + encoded_tweak + message_parts[0] + return poseidon_compress(input_vec, 16, HASH_LEN_FE) + + elif len(message_parts) == 2: + # Case 2: Hashing two digests (used for Merkle tree nodes). + # + # We use the width-24 compression mode. + input_vec = ( + parameter + encoded_tweak + message_parts[0] + message_parts[1] + ) + return poseidon_compress(input_vec, 24, HASH_LEN_FE) + + else: + # Case 3: Hashing many digests (used for the Merkle tree leaf). + # + # We use the robust sponge mode. + # First, flatten the list of message parts into a single vector. + flattened_message = list(chain.from_iterable(message_parts)) + input_vec = parameter + encoded_tweak + flattened_message + + # Create a domain separator based on the dimensions of the input. 
+ lengths = [PARAMETER_LEN, TWEAK_LEN_FE, DIMENSION, HASH_LEN_FE] + capacity_value = _poseidon_safe_domain_separator(lengths, CAPACITY) + + return _poseidon_sponge(input_vec, capacity_value, HASH_LEN_FE) + + +def hash_chain( + parameter: Parameter, + epoch: int, + chain_index: int, + start_step: int, + num_steps: int, + start_digest: HashDigest, +) -> HashDigest: + """ + Performs repeated hashing to traverse a WOTS+ hash chain. + + Args: + parameter: The public parameter `P`. + epoch: The signature epoch. + chain_index: The index of the hash chain. + start_step: The starting step number in the chain. + num_steps: The number of hashing steps to perform. + start_digest: The digest to begin hashing from. + + Returns: + The final hash digest after `num_steps` applications. + """ + current_digest = start_digest + for i in range(num_steps): + # Create a unique tweak for the current position in the chain. + tweak = ChainTweak( + epoch=epoch, chain_index=chain_index, step=start_step + i + 1 + ) + # Apply the hash function. + current_digest = apply(parameter, tweak, [current_digest]) + return current_digest + + +def rand_parameter() -> Parameter: + """ + Generates a cryptographically secure random public parameter. + + Returns: + A new, randomly generated list of `PARAMETER_LEN` field elements. + """ + # For each element in the list, generate a secure random integer + # in the range [0, P-1] and convert it to a field element. + # `secrets.randbelow(P)` is used to avoid modulo bias. + return [Fp(value=secrets.randbelow(P)) for _ in range(PARAMETER_LEN)] + + +def rand_domain() -> HashDigest: + """ + Generates a cryptographically secure random hash digest. + + This is used for testing or as a starting point for hash chains + where a random seed is required. + + Returns: + A new, randomly generated list of `HASH_LEN_FE` field elements. + """ + # For each element in the list, generate a secure random integer + # in the range [0, P-1] and convert it to a field element. 
+ # `secrets.randbelow(P)` is used to avoid modulo bias. + return [Fp(value=secrets.randbelow(P)) for _ in range(HASH_LEN_FE)] diff --git a/tests/lean_spec/subspecs/xmss/test_hypercube.py b/tests/lean_spec/subspecs/xmss/test_hypercube.py new file mode 100644 index 00000000..612ee858 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_hypercube.py @@ -0,0 +1,225 @@ +""" +Tests for the hypercube mathematical operations. + +This module provides extensive tests for the hypercube logic, ensuring that +the precomputation of layer sizes is correct and that the mappings between +integer indices and hypercube vertices are bijective and accurate. + +The tests are designed to be an exact equivalent of the Rust reference tests. +""" + +import math +from functools import lru_cache +from typing import List + +import pytest + +from lean_spec.subspecs.xmss.hypercube import ( + MAX_DIMENSION, + find_layer, + get_hypercube_part_size, + get_layer_size, + map_to_vertex, + prepare_layer_info, +) + + +def map_to_integer(w: int, v: int, d: int, a: List[int]) -> int: + """ + Maps a vertex `a` in layer `d` back to its integer index. + This is a direct translation of the reference Rust implementation. + """ + if len(a) != v: + raise ValueError("Vertex length must equal dimension v.") + + layer_info_cache = prepare_layer_info(w) + x_curr = 0 + + # Initialize `d_curr` with the distance contribution + # of the last coordinate. + d_curr = (w - 1) - a[v - 1] + + # Loop backwards from the second-to-last coordinate to the first. + for i in range(v - 2, -1, -1): + ji = (w - 1) - a[i] + d_curr += ji + + # Dimension of the subproblem at this stage. + rem_dim = v - 1 - i + prev_dim_layer_info = layer_info_cache[rem_dim] + + # Calculate the start of the summation range for the subproblem. + j_start = max(0, d_curr - (w - 1) * rem_dim) + + # Add the sizes of all blocks that come before the current one. 
+ sum_start = d_curr - (ji - 1) + sum_end = d_curr - j_start + x_curr += prev_dim_layer_info.sizes_sum_in_range(sum_start, sum_end) + + # At the end, the incrementally built distance must + # equal the target distance. + assert d_curr == d + + # The final accumulated value is the index. + return x_curr + + +@lru_cache(maxsize=None) +def _binom(n: int, k: int) -> int: + """A cached binomial coefficient calculator (n choose k).""" + if k < 0 or k > n: + return 0 + return math.comb(n, k) + + +def _nb(k: int, m: int, n: int) -> int: + """ + Computes the number of integer vectors of dimension `n` with entries in + [0, m] that sum to `k`. This is equivalent to the coefficient of x^k in + the polynomial (1 + x + ... + x^m)^n. + """ + total = 0 + for s in range(k // (m + 1) + 1): + term = _binom(n, s) * _binom(k - s * (m + 1) + n - 1, n - 1) + if s % 2 == 0: + total += term + else: + total -= term + return total + + +def _prepare_layer_sizes_by_binom(w: int) -> list[list[int]]: + """ + A reference implementation to calculate layer sizes using binomial + coefficients. It's slower but simpler, making it good for validation. + """ + all_layers: List[List[int]] = [[] for _ in range(MAX_DIMENSION + 1)] + for v in range(1, MAX_DIMENSION + 1): + max_distance = (w - 1) * v + layer_sizes: List[int] = [] + for d in range(max_distance + 1): + # The sum of coordinates is `k = v * (w - 1) - d`. + # We need to find the number of ways to write `k` as a sum of `v` + # integers, each between 0 and `w-1`. + coord_sum = v * (w - 1) - d + layer_sizes.append(_nb(coord_sum, w - 1, v)) + all_layers[v] = layer_sizes + return all_layers + + +def test_prepare_layer_sizes_against_reference() -> None: + """ + Validates the optimized `prepare_layer_info` against the slower, + math-based reference implementation for a range of `w` values. 
+    """
+    for w in range(2, 7):
+        expected_sizes_by_v = _prepare_layer_sizes_by_binom(w)
+        actual_info_by_v = prepare_layer_info(w)
+
+        for v in range(1, MAX_DIMENSION + 1):
+            # Note: The reference implementation returns reversed layer sizes.
+            # Layer `d` in our spec corresponds to sum `k = v*(w-1) - d`.
+            expected_sizes_reordered = list(reversed(expected_sizes_by_v[v]))
+            actual_sizes = actual_info_by_v[v].sizes
+            assert expected_sizes_reordered == actual_sizes
+
+
+@pytest.mark.parametrize(
+    "w, v, d, expected_size",
+    [
+        (2, 1, 0, 1),
+        (2, 1, 1, 2),
+        (3, 2, 0, 1),
+        (3, 2, 1, 3),
+        (3, 2, 2, 6),
+        (3, 2, 3, 8),
+        (3, 2, 4, 9),
+        (2, 3, 0, 1),
+        (2, 3, 1, 4),
+        (2, 3, 2, 7),
+        (2, 3, 3, 8),
+    ],
+)
+def test_get_hypercube_part_size(
+    w: int, v: int, d: int, expected_size: int
+) -> None:
+    """
+    Tests `get_hypercube_part_size` with known values from the Rust tests.
+    """
+    assert get_hypercube_part_size(w, v, d) == expected_size
+
+
+def test_find_layer_boundaries() -> None:
+    """
+    Tests `find_layer` with specific boundary-crossing values.
+    """
+    w, v = 3, 2
+    # Layer sizes for (w=3, v=2) are [1, 2, 3, 2, 1]
+    # Prefix sums are [1, 3, 6, 8, 9]
+
+    # x=0 is the 1st element, which is in layer 0. Remainder is 0.
+    assert find_layer(w, v, 0) == (0, 0)
+    # x=1 is the 2nd element, which is the 1st element in layer 1.
+    # Remainder is 0.
+    assert find_layer(w, v, 1) == (1, 0)
+    # x=2 is the 3rd element, which is the 2nd element in layer 1.
+    # Remainder is 1.
+    assert find_layer(w, v, 2) == (1, 1)
+    # x=3 is the 4th element, which is the 1st element in layer 2.
+    # Remainder is 0.
+    assert find_layer(w, v, 3) == (2, 0)
+    # x=5 is the 6th element, which is the third (last) element in layer 2.
+    # Remainder is 2.
+    assert find_layer(w, v, 5) == (2, 2)
+    # x=6 is the 7th element, which is the first element in layer 3.
+    # Remainder is 0.
+    assert find_layer(w, v, 6) == (3, 0)
+    # x=8 is the 9th element, which is the 1st element in layer 4.
+    # Remainder is 0.
+ assert find_layer(w, v, 8) == (4, 0) + + +def test_map_to_vertex_roundtrip() -> None: + """ + Tests the map_to_vertex and map_to_integer roundtrip for a small case. + This test is slow and only checks a limited range. + """ + w, v, d = 4, 8, 20 + max_x = get_layer_size(w, v, d) + + # Iterate through every possible index in a specific layer + # and check roundtrip + for x in range( + min(max_x, 100) + ): # Capped at 100 iterations to keep test fast + vertex = map_to_vertex(w, v, d, x) + + # Check that the vertex sum corresponds to the correct layer + coord_sum = sum(vertex) + assert (w - 1) * v - coord_sum == d + + # Check that mapping back gives the original index + x_reconstructed = map_to_integer(w, v, d, vertex) + assert x_reconstructed == x + + +def test_big_map() -> None: + """ + Tests the full map_to_vertex -> map_to_integer roundtrip with a very + large number, exactly replicating the Rust reference test. + """ + w, v, d = 12, 40, 174 + x = 21790506781852242898091207809690042074412 + + # Map the integer to a vertex. + vertex_a = map_to_vertex(w, v, d, x) + + # Map the vertex back to an integer. + x_reconstructed = map_to_integer(w, v, d, vertex_a) + + # Map the reconstructed integer back to a vertex again. + vertex_b = map_to_vertex(w, v, d, x_reconstructed) + + # Assert that both steps of the roundtrip were successful. + assert x == x_reconstructed + assert vertex_a == vertex_b diff --git a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py new file mode 100644 index 00000000..564fbe3a --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py @@ -0,0 +1,104 @@ +""" +Tests for the sparse Merkle tree implementation. + +This module verifies the correctness of the Merkle tree construction, +path generation, and verification logic by performing a full +"commit-open-verify" roundtrip across various tree configurations. 
+""" + +import pytest + +from lean_spec.subspecs.xmss.merkle_tree import ( + build_tree, + get_path, + get_root, + verify_path, +) +from lean_spec.subspecs.xmss.structures import HashDigest, Parameter +from lean_spec.subspecs.xmss.tweak_hash import ( + TreeTweak, + rand_domain, + rand_parameter, +) +from lean_spec.subspecs.xmss.tweak_hash import ( + apply as apply_tweakable_hash, +) + + +def _run_commit_open_verify_roundtrip( + num_leaves: int, depth: int, start_index: int, leaf_parts_len: int +) -> None: + """ + A helper function to perform a full Merkle tree roundtrip test. + + The process is as follows: + 1. Generate random leaf data. + 2. Hash the leaves to create layer 0 of the tree. + 3. Build the full Merkle tree and get its root (commit). + 4. For each leaf, generate an authentication path (open). + 5. Verify that each path is valid for its corresponding leaf and root. + + Args: + num_leaves: The number of active leaves in the tree. + start_index: The starting index of the first active leaf. + leaf_parts_len: The number of digests that constitute a single leaf. + """ + # SETUP: Generate a random parameter and the raw leaf data. + parameter: Parameter = rand_parameter() + leaves: list[list[HashDigest]] = [ + [rand_domain() for _ in range(leaf_parts_len)] + for _ in range(num_leaves) + ] + + # HASH LEAVES: Compute the layer 0 nodes by hashing the leaf parts. + leaf_hashes: list[HashDigest] = [ + apply_tweakable_hash( + parameter, + TreeTweak(level=0, index=start_index + i), + leaf_parts, + ) + for i, leaf_parts in enumerate(leaves) + ] + + # COMMIT: Build the Merkle tree from the leaf hashes. + tree = build_tree(depth, start_index, parameter, leaf_hashes) + root = get_root(tree) + + # OPEN & VERIFY: For each leaf, generate and verify its path. 
+ for i, leaf_parts in enumerate(leaves): + position = start_index + i + opening = get_path(tree, position) + is_valid = verify_path(parameter, root, position, leaf_parts, opening) + assert is_valid, f"Verification failed for leaf at position {position}" + + +@pytest.mark.parametrize( + "num_leaves, depth, start_index, leaf_parts_len, description", + [ + (16, 4, 0, 3, "Full tree (depth 4)"), + (12, 5, 0, 5, "Half tree, left-aligned (depth 5)"), + (16, 5, 16, 2, "Half tree, right-aligned (depth 5)"), + (22, 6, 13, 3, "Sparse, non-aligned tree (depth 6)"), + (2, 2, 2, 6, "Half tree, right-aligned (small)"), + (1, 1, 0, 1, "Tree with a single leaf at the start"), + (1, 1, 1, 1, "Tree with a single leaf at an odd index"), + (16, 5, 7, 2, "Small sparse tree starting at an odd index"), + ], +) +def test_commit_open_verify_roundtrip( + num_leaves: int, + depth: int, + start_index: int, + leaf_parts_len: int, + description: str, +) -> None: + """ + Tests the Merkle tree logic for various configurations using test-specific + tree depths for efficiency. + """ + # Ensure the test case parameters are valid for the specified tree depth. + assert start_index + num_leaves <= (1 << depth) + + _run_commit_open_verify_roundtrip( + num_leaves, depth, start_index, leaf_parts_len + ) diff --git a/tests/lean_spec/subspecs/xmss/test_prf.py b/tests/lean_spec/subspecs/xmss/test_prf.py new file mode 100644 index 00000000..236d8cdf --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_prf.py @@ -0,0 +1,67 @@ +""" +Tests for the SHAKE128-based pseudorandom function (PRF). +""" + +from lean_spec.subspecs.xmss.constants import HASH_LEN_FE, PRF_KEY_LENGTH +from lean_spec.subspecs.xmss.prf import apply, key_gen + + +def test_key_gen_is_random() -> None: + """ + Performs a sanity check on `key_gen` to ensure it's not deterministic + or producing trivial outputs. + + This test mirrors the logic from the reference Rust implementation. + """ + # Check that the key has the correct length. 
+ key = key_gen() + assert len(key) == PRF_KEY_LENGTH + + # Generate multiple keys and ensure they are not all identical. + # + # This is a basic check to ensure we are getting fresh randomness. + num_trials = 10 + keys = {key_gen() for _ in range(num_trials)} + assert len(keys) == num_trials + + # Check that the keys are not filled with a single repeated byte. + # + # It is astronomically unlikely for a secure random generator to produce + # such a key, so this is a good health check. + all_same_count = 0 + for _ in range(num_trials): + key = key_gen() + # A set will have size 1 if all elements are the same. + if len(set(key)) == 1: + all_same_count += 1 + assert all_same_count < num_trials, "key_gen produced non-random keys" + + +def test_apply_is_sensitive_to_inputs() -> None: + """ + Tests that changing any input to `apply` results in a different output. + + This confirms that all parts of the input (key, epoch, chain_index) are + being correctly absorbed by the hash function. + """ + # Generate a baseline output with a set of initial inputs. + key1 = b"\x11" * PRF_KEY_LENGTH + epoch1 = 10 + chain_index1 = 20 + baseline_output = apply(key1, epoch1, chain_index1) + assert len(baseline_output) == HASH_LEN_FE + + # Test sensitivity to the key. + key2 = b"\x22" * PRF_KEY_LENGTH + output_key_changed = apply(key2, epoch1, chain_index1) + assert baseline_output != output_key_changed + + # Test sensitivity to the epoch. + epoch2 = 11 + output_epoch_changed = apply(key1, epoch2, chain_index1) + assert baseline_output != output_epoch_changed + + # Test sensitivity to the chain_index. 
+ chain_index2 = 21 + output_index_changed = apply(key1, epoch1, chain_index2) + assert baseline_output != output_index_changed From 8b62fb3a9b1d7ef709eccaa506cfd4734a6cbc81 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 15:15:56 +0200 Subject: [PATCH 02/13] add message hash and tests --- src/lean_spec/subspecs/xmss/message_hash.py | 135 ++++++++++++++++++ src/lean_spec/subspecs/xmss/utils.py | 29 ++++ .../subspecs/xmss/test_message_hash.py | 99 +++++++++++++ 3 files changed, 263 insertions(+) create mode 100644 src/lean_spec/subspecs/xmss/message_hash.py create mode 100644 src/lean_spec/subspecs/xmss/utils.py create mode 100644 tests/lean_spec/subspecs/xmss/test_message_hash.py diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py new file mode 100644 index 00000000..eb35550f --- /dev/null +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -0,0 +1,135 @@ +""" +Defines the "Top Level" message hashing for the signature scheme. + +This module implements the logic for hashing a message into the top layers of +a hypercube, which is the first step in the "Top Level Target Sum" encoding. + +The process involves: +1. Encoding the message, epoch, and randomness into field elements. +2. Hashing these elements with Poseidon2 to produce a digest. +3. Interpreting the digest as a large integer and mapping it to a unique + vertex within the allowed top layers of the hypercube. 
+""" + +from __future__ import annotations + +from typing import List + +from ..koalabear import Fp, P +from .constants import ( + BASE, + DIMENSION, + FINAL_LAYER, + MSG_LEN_FE, + POS_INVOCATIONS, + POS_OUTPUT_LEN_PER_INV_FE, + TWEAK_LEN_FE, + TWEAK_PREFIX_MESSAGE, +) +from .hypercube import find_layer, get_hypercube_part_size, map_to_vertex +from .structures import Parameter, Randomness +from .tweak_hash import poseidon_compress + + +def encode_message(message: bytes) -> List[Fp]: + """ + Encodes a 32-byte message into a list of field elements. + + The message bytes are interpreted as a single little-endian integer, + which is then decomposed into its base-`P` representation. + """ + # Interpret the little-endian bytes as a single large integer. + acc = int.from_bytes(message, "little") + + # Decompose the integer into a list of field elements (base-P). + elements: List[Fp] = [] + for _ in range(MSG_LEN_FE): + elements.append(Fp(value=acc)) + acc //= P + return elements + + +def encode_epoch(epoch: int) -> List[Fp]: + """Encodes epoch and domain separator into a list of field elements.""" + # Combine the epoch and the message hash prefix into a single integer. + acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE.value + + # Decompose the integer into its base-P representation. + elements: List[Fp] = [] + for _ in range(TWEAK_LEN_FE): + elements.append(Fp(value=acc % P)) + acc //= P + return elements + + +def _map_into_hypercube_part(field_elements: List[Fp]) -> List[int]: + """ + Maps a list of field elements to a vertex in + the top `FINAL_LAYER` layers. + """ + # Combine field elements into one large integer (big-endian, base-P). + acc = 0 + for fe in field_elements: + acc = acc * P + fe.value + + # Reduce this integer modulo the size of the target domain. + # + # The target domain is the set of all vertices in layers 0..FINAL_LAYER. 
+ domain_size = get_hypercube_part_size(BASE, DIMENSION, FINAL_LAYER) + acc %= domain_size + + # Find which layer the resulting index falls into, and its offset. + layer, offset = find_layer(BASE, DIMENSION, acc) + + # Map the offset within the layer to a unique vertex. + return map_to_vertex(BASE, DIMENSION, layer, offset) + + +def apply( + parameter: Parameter, + epoch: int, + randomness: Randomness, + message: bytes, +) -> List[int]: + """ + Applies the full "Top Level" message hash procedure. + + This involves multiple invocations of Poseidon2, with the combined output + mapped into a specific region of the hypercube. + + Args: + parameter: The public parameter `P`. + epoch: The current epoch. + randomness: A random value `rho`. + message: The 32-byte message to be hashed. + + Returns: + A vertex in the hypercube, represented as a list of `DIMENSION` ints. + """ + # Encode the message and epoch as field elements. + message_fe = encode_message(message) + epoch_fe = encode_epoch(epoch) + + # Iteratively call Poseidon2 to generate a long hash output. + poseidon_outputs: List[Fp] = [] + for i in range(POS_INVOCATIONS): + # Use the iteration number as a domain separator for each hash call. + iteration_separator = [Fp(value=i)] + + # The input is: rho || P || epoch || message || iteration. + combined_input = ( + randomness + + parameter + + epoch_fe + + message_fe + + iteration_separator + ) + + # Hash the combined input using Poseidon2 compression mode. + iteration_output = poseidon_compress( + combined_input, 24, POS_OUTPUT_LEN_PER_INV_FE + ) + poseidon_outputs.extend(iteration_output) + + # Map the final list of field elements into a hypercube vertex. 
+ return _map_into_hypercube_part(poseidon_outputs) diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py new file mode 100644 index 00000000..9c04efc1 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/utils.py @@ -0,0 +1,29 @@ +"""Utility functions for the XMSS signature scheme.""" + +import secrets +from typing import List + +from ..koalabear import Fp, P +from .constants import HASH_LEN_FE, PARAMETER_LEN, RAND_LEN_FE +from .structures import HashDigest, Parameter, Randomness + + +def rand_field_elements(length: int) -> List[Fp]: + """Generates a random list of field elements.""" + # For each element, generate a secure random integer in the range [0, P-1]. + return [Fp(value=secrets.randbelow(P)) for _ in range(length)] + + +def rand_parameter() -> Parameter: + """Generates a random public parameter.""" + return rand_field_elements(PARAMETER_LEN) + + +def rand_domain() -> HashDigest: + """Generates a random hash digest.""" + return rand_field_elements(HASH_LEN_FE) + + +def rand_rho() -> Randomness: + """Generates randomness `rho` for message encoding.""" + return rand_field_elements(RAND_LEN_FE) diff --git a/tests/lean_spec/subspecs/xmss/test_message_hash.py b/tests/lean_spec/subspecs/xmss/test_message_hash.py new file mode 100644 index 00000000..24d007a1 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_message_hash.py @@ -0,0 +1,99 @@ +""" +Tests for the "Top Level" message hashing and encoding logic. 
+""" + +from typing import List + +from lean_spec.subspecs.koalabear import Fp, P +from lean_spec.subspecs.xmss.constants import ( + BASE, + DIMENSION, + FINAL_LAYER, + MESSAGE_LENGTH, + MSG_LEN_FE, + TWEAK_LEN_FE, + TWEAK_PREFIX_MESSAGE, +) +from lean_spec.subspecs.xmss.message_hash import ( + apply, + encode_epoch, + encode_message, +) +from lean_spec.subspecs.xmss.tweak_hash import rand_parameter +from lean_spec.subspecs.xmss.utils import rand_rho + + +def test_encode_message() -> None: + """Tests `encode_message` with various message patterns.""" + # All-zero message + msg_zeros = b"\x00" * MESSAGE_LENGTH + encoded_zeros = encode_message(msg_zeros) + assert len(encoded_zeros) == MSG_LEN_FE + assert all(fe.value == 0 for fe in encoded_zeros) + + # All-max message (0xff) + msg_max = b"\xff" * MESSAGE_LENGTH + acc = int.from_bytes(msg_max, "little") + expected_max: List[Fp] = [] + for _ in range(MSG_LEN_FE): + expected_max.append(Fp(value=acc)) + acc //= P + assert encode_message(msg_max) == expected_max + + +def test_encode_epoch() -> None: + """ + Tests `encode_epoch` for correctness and injectivity. + """ + # Test specific values from the Rust reference tests. + test_epochs = [0, 42, 2**32 - 1] + for epoch in test_epochs: + acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE.value + expected: List[Fp] = [] + for _ in range(TWEAK_LEN_FE): + expected.append(Fp(value=acc)) + acc //= P + assert encode_epoch(epoch) == expected + + # Test for injectivity. It is highly unlikely for a collision to occur + # with a few random samples if the encoding is injective. + num_trials = 1000 + seen_encodings: set[tuple[Fp, ...]] = set() + for i in range(num_trials): + encoding = tuple(encode_epoch(i)) + assert encoding not in seen_encodings + seen_encodings.add(encoding) + + +def test_apply_output_is_in_correct_hypercube_part() -> None: + """ + Tests that the output of `apply` is a valid vertex that lies within + the top `FINAL_LAYER` layers of the hypercube. 
+ """ + # Setup with random inputs. + parameter = rand_parameter() + epoch = 313 + randomness = rand_rho() + message = b"\xaa" * MESSAGE_LENGTH + + # Call the message hash function. + vertex = apply(parameter, epoch, randomness, message) + + # Verify the properties of the output vertex. + # + # The length of the vertex must be equal to the hypercube's dimension. + assert len(vertex) == DIMENSION + # Each coordinate must be smaller than the base `w`. + assert all(0 <= coord < BASE for coord in vertex) + + # Check that the vertex lies in the correct set of layers. + # + # By definition, a vertex is in layer `d` if `d = v*(w-1) - sum(coords)`. + # + # We require `d <= FINAL_LAYER`. + # + # This is equivalent to `sum(coords) >= v*(w-1) - FINAL_LAYER`. + coord_sum = sum(vertex) + min_required_sum = (BASE - 1) * DIMENSION - FINAL_LAYER + + assert coord_sum >= min_required_sum, "Vertex is not in the top layers" From 337d072bf45f71c87abde7acfa5183e1f522faea Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 15:18:47 +0200 Subject: [PATCH 03/13] utils: rm duplicate functions --- src/lean_spec/subspecs/xmss/interface.py | 2 +- src/lean_spec/subspecs/xmss/merkle_tree.py | 3 +- src/lean_spec/subspecs/xmss/tweak_hash.py | 30 ------------------- .../subspecs/xmss/test_merkle_tree.py | 6 ++-- .../subspecs/xmss/test_message_hash.py | 3 +- 5 files changed, 8 insertions(+), 36 deletions(-) diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 89aa19a8..2fcacbd1 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -18,11 +18,11 @@ from .tweak_hash import ( TreeTweak, hash_chain, - rand_parameter, ) from .tweak_hash import ( apply as apply_tweakable_hash, ) +from .utils import rand_parameter def key_gen( diff --git a/src/lean_spec/subspecs/xmss/merkle_tree.py b/src/lean_spec/subspecs/xmss/merkle_tree.py index d6eaff4a..09db20e4 100644 --- 
a/src/lean_spec/subspecs/xmss/merkle_tree.py +++ b/src/lean_spec/subspecs/xmss/merkle_tree.py @@ -25,8 +25,9 @@ HashTreeOpening, Parameter, ) -from .tweak_hash import TreeTweak, rand_domain +from .tweak_hash import TreeTweak from .tweak_hash import apply as apply_tweakable_hash +from .utils import rand_domain def _get_padded_layer( diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py index 67f18e79..24508923 100644 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -11,7 +11,6 @@ from __future__ import annotations -import secrets from itertools import chain from typing import List, Union @@ -296,32 +295,3 @@ def hash_chain( # Apply the hash function. current_digest = apply(parameter, tweak, [current_digest]) return current_digest - - -def rand_parameter() -> Parameter: - """ - Generates a cryptographically secure random public parameter. - - Returns: - A new, randomly generated list of `PARAMETER_LEN` field elements. - """ - # For each element in the list, generate a secure random integer - # in the range [0, P-1] and convert it to a field element. - # `secrets.randbelow(P)` is used to avoid modulo bias. - return [Fp(value=secrets.randbelow(P)) for _ in range(PARAMETER_LEN)] - - -def rand_domain() -> HashDigest: - """ - Generates a cryptographically secure random hash digest. - - This is used for testing or as a starting point for hash chains - where a random seed is required. - - Returns: - A new, randomly generated list of `HASH_LEN_FE` field elements. - """ - # For each element in the list, generate a secure random integer - # in the range [0, P-1] and convert it to a field element. - # `secrets.randbelow(P)` is used to avoid modulo bias. 
- return [Fp(value=secrets.randbelow(P)) for _ in range(HASH_LEN_FE)] diff --git a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py index 564fbe3a..1d332f14 100644 --- a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py +++ b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py @@ -17,12 +17,14 @@ from lean_spec.subspecs.xmss.structures import HashDigest, Parameter from lean_spec.subspecs.xmss.tweak_hash import ( TreeTweak, - rand_domain, - rand_parameter, ) from lean_spec.subspecs.xmss.tweak_hash import ( apply as apply_tweakable_hash, ) +from lean_spec.subspecs.xmss.utils import ( + rand_domain, + rand_parameter, +) def _run_commit_open_verify_roundtrip( diff --git a/tests/lean_spec/subspecs/xmss/test_message_hash.py b/tests/lean_spec/subspecs/xmss/test_message_hash.py index 24d007a1..6029fd3b 100644 --- a/tests/lean_spec/subspecs/xmss/test_message_hash.py +++ b/tests/lean_spec/subspecs/xmss/test_message_hash.py @@ -19,8 +19,7 @@ encode_epoch, encode_message, ) -from lean_spec.subspecs.xmss.tweak_hash import rand_parameter -from lean_spec.subspecs.xmss.utils import rand_rho +from lean_spec.subspecs.xmss.utils import rand_parameter, rand_rho def test_encode_message() -> None: From 4322e87886a73c848a195830d32b3dd13c7b766b Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 15:20:22 +0200 Subject: [PATCH 04/13] small touchup --- src/lean_spec/subspecs/xmss/message_hash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py index eb35550f..4161ffc9 100644 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -57,7 +57,7 @@ def encode_epoch(epoch: int) -> List[Fp]: # Decompose the integer into its base-P representation. 
elements: List[Fp] = [] for _ in range(TWEAK_LEN_FE): - elements.append(Fp(value=acc % P)) + elements.append(Fp(value=acc)) acc //= P return elements From be417d8e7beef1489123369a5d52d8616c8ac3a5 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 15:22:10 +0200 Subject: [PATCH 05/13] small doc touchup --- src/lean_spec/subspecs/xmss/hypercube.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/lean_spec/subspecs/xmss/hypercube.py b/src/lean_spec/subspecs/xmss/hypercube.py index a553480d..42450857 100644 --- a/src/lean_spec/subspecs/xmss/hypercube.py +++ b/src/lean_spec/subspecs/xmss/hypercube.py @@ -83,7 +83,8 @@ def prepare_layer_info(w: int) -> List[LayerInfo]: # Index 0 is unused to allow for direct indexing, e.g., `all_info[v]`. all_info = [LayerInfo(sizes=[], prefix_sums=[])] * (MAX_DIMENSION + 1) - # -- BASE CASE -- + # BASE CASE + # # For a 1-dimensional hypercube (v=1), which is just a line of `w` points # with coordinates [0], [1], ..., [w-1]. # The distance `d` from the sink `[w-1]` is simply `(w-1) - coordinate`. @@ -95,7 +96,8 @@ def prepare_layer_info(w: int) -> List[LayerInfo]: # Store the result for v=1, which will seed the inductive step. all_info[1] = LayerInfo(sizes=dim1_sizes, prefix_sums=dim1_prefix_sums) - # -- INDUCTIVE STEP -- + # INDUCTIVE STEP + # # Now, build the layer info for all higher dimensions up to the maximum. for v in range(2, MAX_DIMENSION + 1): # The maximum possible distance `d` in a v-dimensional hypercube. 
From d1496a56fbfd2b424b7e6a2b7d5634340a7f5460 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 15:34:39 +0200 Subject: [PATCH 06/13] complete the interface --- src/lean_spec/subspecs/xmss/interface.py | 171 +++++++++++----------- src/lean_spec/subspecs/xmss/target_sum.py | 49 +++++++ 2 files changed, 132 insertions(+), 88 deletions(-) create mode 100644 src/lean_spec/subspecs/xmss/target_sum.py diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 2fcacbd1..ae39610b 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -1,17 +1,19 @@ """ Defines the core interface for the Generalized XMSS signature scheme. -Specification for the high-level functions (`key_gen`, `sign`, `verify`) -that constitute the public API of the signature scheme. For the purpose of this -specification, these are defined as placeholders with detailed documentation. +Specification for the high-level functions (`key_gen`, `sign`, `verify`). + +This constitutes the public API of the signature scheme. 
""" from __future__ import annotations from typing import List, Tuple -from .constants import BASE, DIMENSION, LIFETIME, LOG_LIFETIME -from .merkle_tree import build_tree, get_root +from lean_spec.subspecs.xmss.target_sum import encode + +from .constants import BASE, DIMENSION, LIFETIME, LOG_LIFETIME, MAX_TRIES +from .merkle_tree import build_tree, get_path, get_root, verify_path from .prf import apply as apply_prf from .prf import key_gen as prf_key_gen from .structures import HashDigest, PublicKey, SecretKey, Signature @@ -22,7 +24,7 @@ from .tweak_hash import ( apply as apply_tweakable_hash, ) -from .utils import rand_parameter +from .utils import rand_parameter, rand_rho def key_gen( @@ -122,53 +124,49 @@ def sign(sk: SecretKey, epoch: int, message: bytes) -> Signature: - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - The canonical Rust implementation: https://github.com/b-wagn/hash-sig """ - # # 1. Check that the key is active for the requested signing epoch. - # active_range = range( - # sk.activation_epoch, sk.activation_epoch + sk.num_active_epochs - # ) - # if epoch not in active_range: - # raise ValueError("Key is not active for the specified epoch.") - - # # 2. Find a valid message encoding by trying different randomness `rho`. - # codeword = None - # rho = None - # for _ in range(MAX_TRIES): - # current_rho = rand_rho() - # current_codeword = encode(sk.parameter, message, current_rho, epoch) - # if current_codeword is not None: - # codeword = current_codeword - # rho = current_rho - # break - - # # If no valid encoding is found after many tries, signing fails. - # if codeword is None or rho is None: - # raise RuntimeError("Failed to find a valid message encoding.") - - # # 3. Compute the one-time signature hashes based on the codeword. - # ots_hashes: List[HashDigest] = [] - # for chain_index, steps in enumerate(codeword): - # # Derive the secret start of the chain from the master key. 
- # start_digest = apply_prf(sk.prf_key, epoch, chain_index) - # # Walk the chain for the number of steps given by the codeword digit. - # ots_digest = hash_chain( - # parameter=sk.parameter, - # epoch=epoch, - # chain_index=chain_index, - # start_step=0, - # num_steps=steps, - # start_digest=start_digest, - # ) - # ots_hashes.append(ots_digest) - - # # 4. Get the Merkle authentication path for the current epoch. - # path = get_path(sk.tree, epoch) - - # # 5. Assemble and return the final signature. - # return Signature(path=path, rho=rho, hashes=ots_hashes) - raise NotImplementedError( - "sign is not part of this specification. " - "See the Rust reference implementation." + # Check that the key is active for the requested signing epoch. + active_range = range( + sk.activation_epoch, sk.activation_epoch + sk.num_active_epochs ) + if epoch not in active_range: + raise ValueError("Key is not active for the specified epoch.") + + # Find a valid message encoding by trying different randomness `rho`. + codeword = None + rho = None + for _ in range(MAX_TRIES): + current_rho = rand_rho() + current_codeword = encode(sk.parameter, message, current_rho, epoch) + if current_codeword is not None: + codeword = current_codeword + rho = current_rho + break + + # If no valid encoding is found after many tries, signing fails. + if codeword is None or rho is None: + raise RuntimeError("Failed to find a valid message encoding.") + + # Compute the one-time signature hashes based on the codeword. + ots_hashes: List[HashDigest] = [] + for chain_index, steps in enumerate(codeword): + # Derive the secret start of the chain from the master key. + start_digest = apply_prf(sk.prf_key, epoch, chain_index) + # Walk the chain for the number of steps given by the codeword digit. 
+ ots_digest = hash_chain( + parameter=sk.parameter, + epoch=epoch, + chain_index=chain_index, + start_step=0, + num_steps=steps, + start_digest=start_digest, + ) + ots_hashes.append(ots_digest) + + # Get the Merkle authentication path for the current epoch. + path = get_path(sk.tree, epoch) + + # Assemble and return the final signature. + return Signature(path=path, rho=rho, hashes=ots_hashes) def verify(pk: PublicKey, epoch: int, message: bytes, sig: Signature) -> bool: @@ -214,40 +212,37 @@ def verify(pk: PublicKey, epoch: int, message: bytes, sig: Signature) -> bool: - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - The canonical Rust implementation: https://github.com/b-wagn/hash-sig """ - # # 1. Re-encode the message using the randomness `rho` from the signature. - # # If the encoding is invalid, the signature is invalid. - # codeword = encode(pk.parameter, message, sig.rho, epoch) - # if codeword is None: - # return False - - # # 2. Reconstruct the one-time public key (the list of chain endpoints). - # chain_ends: List[HashDigest] = [] - # for chain_index, xi in enumerate(codeword): - # # The signature provides the hash value after `xi` steps. - # start_digest = sig.hashes[chain_index] - # # We must perform the remaining `BASE - 1 - xi` steps - # # to get to the end. - # num_steps_remaining = BASE - 1 - xi - # end_digest = hash_chain( - # parameter=pk.parameter, - # epoch=epoch, - # chain_index=chain_index, - # start_step=xi, - # num_steps=num_steps_remaining, - # start_digest=start_digest, - # ) - # chain_ends.append(end_digest) - - # # 3. Verify the Merkle path. This function internally hashes `chain_ends` - # # to get the leaf node and then climbs the tree to recompute the root. - # return verify_path( - # parameter=pk.parameter, - # root=pk.root, - # position=epoch, - # leaf_parts=chain_ends, - # opening=sig.path, - # ) - raise NotImplementedError( - "sign is not part of this specification. 
" - "See the Rust reference implementation." + # Re-encode the message using the randomness `rho` from the signature. + # + # If the encoding is invalid, the signature is invalid. + codeword = encode(pk.parameter, message, sig.rho, epoch) + if codeword is None: + return False + + # Reconstruct the one-time public key (the list of chain endpoints). + chain_ends: List[HashDigest] = [] + for chain_index, xi in enumerate(codeword): + # The signature provides the hash value after `xi` steps. + start_digest = sig.hashes[chain_index] + # We must perform the remaining `BASE - 1 - xi` steps + # to get to the end. + num_steps_remaining = BASE - 1 - xi + end_digest = hash_chain( + parameter=pk.parameter, + epoch=epoch, + chain_index=chain_index, + start_step=xi, + num_steps=num_steps_remaining, + start_digest=start_digest, + ) + chain_ends.append(end_digest) + + # Verify the Merkle path. This function internally hashes `chain_ends` + # to get the leaf node and then climbs the tree to recompute the root. + return verify_path( + parameter=pk.parameter, + root=pk.root, + position=epoch, + leaf_parts=chain_ends, + opening=sig.path, ) diff --git a/src/lean_spec/subspecs/xmss/target_sum.py b/src/lean_spec/subspecs/xmss/target_sum.py new file mode 100644 index 00000000..66251f5b --- /dev/null +++ b/src/lean_spec/subspecs/xmss/target_sum.py @@ -0,0 +1,49 @@ +""" +Implements the Top Level Target Sum Winternitz incomparable encoding scheme. + +This module provides the logic for converting a message hash into a valid +codeword for the one-time signature part of the scheme. It acts as a filter on +top of the message hash output. 
+""" + +from typing import List, Optional + +from .constants import TARGET_SUM +from .message_hash import apply as apply_message_hash +from .structures import Parameter, Randomness + + +def encode( + parameter: Parameter, message: bytes, rho: Randomness, epoch: int +) -> Optional[List[int]]: + """ + Encodes a message into a codeword if it meets the target sum criteria. + + This function first uses the message hash to map the input to a vertex in + the hypercube. It then checks if the sum of the vertex's coordinates + matches the scheme's `TARGET_SUM`. This filtering step is the core of the + Target Sum scheme. + + Args: + parameter: The public parameter `P`. + message: The message to encode. + rho: The randomness used for this encoding attempt. + epoch: The current epoch. + + Returns: + The codeword (a list of integers) if the sum is correct, + otherwise `None`. + """ + # Apply the message hash to get a potential codeword (a vertex). + codeword_candidate = apply_message_hash(parameter, epoch, rho, message) + + # Check if the candidate satisfies the target sum condition. + if sum(codeword_candidate) == TARGET_SUM: + # If the sum is correct, this is a valid codeword. + return codeword_candidate + else: + # If the sum does not match, this `rho` is invalid for this message. + # + # The caller (the `sign` function) will need to try again with new + # randomness. 
+ return None From 6983c7405da9c997456b8d5e7b7c2ab36cd67b22 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 17:51:52 +0200 Subject: [PATCH 07/13] update using classes --- src/lean_spec/subspecs/xmss/__init__.py | 14 +- src/lean_spec/subspecs/xmss/constants.py | 167 +++--- src/lean_spec/subspecs/xmss/interface.py | 511 ++++++++++-------- src/lean_spec/subspecs/xmss/merkle_tree.py | 406 +++++++------- src/lean_spec/subspecs/xmss/message_hash.py | 239 ++++---- src/lean_spec/subspecs/xmss/poseidon.py | 136 +++++ src/lean_spec/subspecs/xmss/prf.py | 140 ++--- src/lean_spec/subspecs/xmss/structures.py | 24 +- src/lean_spec/subspecs/xmss/target_sum.py | 93 ++-- src/lean_spec/subspecs/xmss/tweak_hash.py | 420 ++++++-------- src/lean_spec/subspecs/xmss/utils.py | 40 +- .../subspecs/xmss/test_merkle_tree.py | 39 +- .../subspecs/xmss/test_message_hash.py | 57 +- tests/lean_spec/subspecs/xmss/test_prf.py | 28 +- 14 files changed, 1279 insertions(+), 1035 deletions(-) create mode 100644 src/lean_spec/subspecs/xmss/poseidon.py diff --git a/src/lean_spec/subspecs/xmss/__init__.py b/src/lean_spec/subspecs/xmss/__init__.py index b992bfe6..2ea7484a 100644 --- a/src/lean_spec/subspecs/xmss/__init__.py +++ b/src/lean_spec/subspecs/xmss/__init__.py @@ -5,9 +5,7 @@ It exposes the core data structures and the main interface functions. 
""" -from .constants import LIFETIME, MESSAGE_LENGTH -from .interface import key_gen, sign, verify -from .merkle_tree import build_tree, get_path, get_root, verify_path +from .interface import GeneralizedXmssScheme from .structures import ( HashTree, HashTreeOpening, @@ -17,18 +15,10 @@ ) __all__ = [ - "key_gen", - "sign", - "verify", + "GeneralizedXmssScheme", "PublicKey", "Signature", "SecretKey", "HashTreeOpening", - "LIFETIME", - "MESSAGE_LENGTH", - "build_tree", - "get_path", - "get_root", - "verify_path", "HashTree", ] diff --git a/src/lean_spec/subspecs/xmss/constants.py b/src/lean_spec/subspecs/xmss/constants.py index eea67dc5..b847c4f6 100644 --- a/src/lean_spec/subspecs/xmss/constants.py +++ b/src/lean_spec/subspecs/xmss/constants.py @@ -1,5 +1,5 @@ """ -Defines the cryptographic constants for the XMSS specification. +Defines the cryptographic constants and configuration presets for the XMSS spec. This specification corresponds to the "hashing-optimized" Top Level Target Sum instantiation from the canonical Rust implementation. @@ -14,102 +14,141 @@ specification in the future. """ -from ..koalabear import Fp - -PRF_KEY_LENGTH: int = 32 -"""The length of the PRF secret key in bytes.""" +from typing import TypeVar +from pydantic import BaseModel, ConfigDict +from typing_extensions import Final -MAX_TRIES: int = 100_000 -""" -How often one should try at most to resample a random value. +from ..koalabear import Fp -This is currently based on experiments with the Rust implementation. -Should probably be modified in production. 
-""" +class XmssConfig(BaseModel): + """A model holding the configuration constants for an XMSS preset.""" + model_config = ConfigDict(frozen=True, extra="forbid") -# ================================================================= -# Core Scheme Configuration -# ================================================================= + # --- Core Scheme Configuration --- + MESSAGE_LENGTH: int + """The length in bytes for all messages to be signed.""" -MESSAGE_LENGTH: int = 32 -"""The length in bytes for all messages to be signed.""" + LOG_LIFETIME: int + """The base-2 logarithm of the scheme's maximum lifetime.""" -LOG_LIFETIME: int = 32 -"""The base-2 logarithm of the scheme's maximum lifetime.""" + @property + def LIFETIME(self) -> int: + """ + The maximum number of epochs supported by this configuration. -LIFETIME: int = 1 << LOG_LIFETIME -""" -The maximum number of epochs supported by this configuration. + An individual key pair can be active for a smaller sub-range. + """ + return 1 << self.LOG_LIFETIME -An individual key pair can be active for a smaller sub-range. -""" + # --- Target Sum WOTS Parameters --- + DIMENSION: int + """The total number of hash chains, `v`.""" + BASE: int + """The alphabet size for the digits of the encoded message.""" -# ================================================================= -# Target Sum WOTS Parameters -# ================================================================= + FINAL_LAYER: int + """The number of top layers of the hypercube to map the hash output into.""" -DIMENSION: int = 64 -"""The total number of hash chains, `v`.""" + TARGET_SUM: int + """The required sum of all codeword chunks for a signature to be valid.""" -BASE: int = 8 -"""The alphabet size for the digits of the encoded message.""" + MAX_TRIES: int + """ + How often one should try at most to resample a random value. 
-FINAL_LAYER: int = 77 -"""The number of top layers of the hypercube to map the hash output into.""" + This is currently based on experiments with the Rust implementation. + Should probably be modified in production. + """ -TARGET_SUM: int = 375 -"""The required sum of all codeword chunks for a signature to be valid.""" + # --- Hash and Encoding Length Parameters (in field elements) --- + PARAMETER_LEN: int + """ + The length of the public parameter `P`. -# ================================================================= -# Hash and Encoding Length Parameters (in field elements) -# ================================================================= + It is used to specialize the hash function. + """ -PARAMETER_LEN: int = 5 -""" -The length of the public parameter `P`. + TWEAK_LEN_FE: int + """The length of a domain-separating tweak.""" -It is used to specialize the hash function. -""" + MSG_LEN_FE: int + """The length of a message after being encoded into field elements.""" -TWEAK_LEN_FE: int = 2 -"""The length of a domain-separating tweak.""" + RAND_LEN_FE: int + """The length of the randomness `rho` used during message encoding.""" -MSG_LEN_FE: int = 9 -"""The length of a message after being encoded into field elements.""" + HASH_LEN_FE: int + """The output length of the main tweakable hash function.""" -RAND_LEN_FE: int = 7 -"""The length of the randomness `rho` used during message encoding.""" + CAPACITY: int + """The capacity of the Poseidon2 sponge, defining its security level.""" -HASH_LEN_FE: int = 8 -"""The output length of the main tweakable hash function.""" + POS_OUTPUT_LEN_PER_INV_FE: int + """Output length per invocation for the message hash.""" -CAPACITY: int = 9 -"""The capacity of the Poseidon2 sponge, defining its security level.""" + POS_INVOCATIONS: int + """Number of invocations for the message hash.""" -POS_OUTPUT_LEN_PER_INV_FE: int = 15 -"""Output length per invocation for the message hash.""" + @property + def POS_OUTPUT_LEN_FE(self) -> int: 
+ """Total output length for the message hash.""" + return self.POS_OUTPUT_LEN_PER_INV_FE * self.POS_INVOCATIONS -POS_INVOCATIONS: int = 1 -"""Number of invocations for the message hash.""" -POS_OUTPUT_LEN_FE: int = POS_OUTPUT_LEN_PER_INV_FE * POS_INVOCATIONS -"""Total output length for the message hash.""" +Config = TypeVar("Config", bound=XmssConfig) +"""A type variable representing any XmssConfig instance.""" +PROD_CONFIG: Final = XmssConfig( + MESSAGE_LENGTH=32, + LOG_LIFETIME=32, + DIMENSION=64, + BASE=8, + FINAL_LAYER=77, + TARGET_SUM=375, + MAX_TRIES=100_000, + PARAMETER_LEN=5, + TWEAK_LEN_FE=2, + MSG_LEN_FE=9, + RAND_LEN_FE=7, + HASH_LEN_FE=8, + CAPACITY=9, + POS_OUTPUT_LEN_PER_INV_FE=15, + POS_INVOCATIONS=1, +) -# ================================================================= -# Domain Separator Prefixes for Tweaks -# ================================================================= -TWEAK_PREFIX_CHAIN = Fp(value=0x00) +TEST_CONFIG: Final = XmssConfig( + MESSAGE_LENGTH=32, + LOG_LIFETIME=8, + DIMENSION=16, + BASE=4, + FINAL_LAYER=24, + TARGET_SUM=24, + MAX_TRIES=100_000, + PARAMETER_LEN=5, + TWEAK_LEN_FE=2, + MSG_LEN_FE=9, + RAND_LEN_FE=7, + HASH_LEN_FE=8, + CAPACITY=9, + POS_OUTPUT_LEN_PER_INV_FE=15, + POS_INVOCATIONS=1, +) + + +TWEAK_PREFIX_CHAIN: Final = Fp(value=0x00) """The unique prefix for tweaks used in Winternitz-style hash chains.""" -TWEAK_PREFIX_TREE = Fp(value=0x01) +TWEAK_PREFIX_TREE: Final = Fp(value=0x01) """The unique prefix for tweaks used when hashing Merkle tree nodes.""" -TWEAK_PREFIX_MESSAGE = Fp(value=0x02) +TWEAK_PREFIX_MESSAGE: Final = Fp(value=0x02) """The unique prefix for tweaks used in the initial message hashing step.""" + +PRF_KEY_LENGTH: int = 32 +"""The length of the PRF secret key in bytes.""" diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index ae39610b..2ad456ea 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ 
-10,239 +10,308 @@ from typing import List, Tuple -from lean_spec.subspecs.xmss.target_sum import encode +from lean_spec.subspecs.xmss.target_sum import ( + PROD_TARGET_SUM_ENCODER, + TEST_TARGET_SUM_ENCODER, + TargetSumEncoder, +) -from .constants import BASE, DIMENSION, LIFETIME, LOG_LIFETIME, MAX_TRIES -from .merkle_tree import build_tree, get_path, get_root, verify_path -from .prf import apply as apply_prf -from .prf import key_gen as prf_key_gen +from .constants import ( + PROD_CONFIG, + TEST_CONFIG, + XmssConfig, +) +from .merkle_tree import ( + PROD_MERKLE_TREE, + TEST_MERKLE_TREE, + MerkleTree, +) +from .prf import PROD_PRF, TEST_PRF, Prf from .structures import HashDigest, PublicKey, SecretKey, Signature from .tweak_hash import ( + PROD_TWEAK_HASHER, + TEST_TWEAK_HASHER, TreeTweak, - hash_chain, -) -from .tweak_hash import ( - apply as apply_tweakable_hash, + TweakHasher, ) -from .utils import rand_parameter, rand_rho - - -def key_gen( - activation_epoch: int, num_active_epochs: int -) -> Tuple[PublicKey, SecretKey]: - """ - Generates a new cryptographic key pair. This is a **randomized** algorithm. - - This function executes the full key generation process: - 1. Generates a master secret (PRF key) and a public hash parameter. - 2. For each epoch in the active range, it uses the PRF to derive the - secret starting points for all `DIMENSION` hash chains. - 3. It computes the public endpoint of each chain by - hashing `BASE - 1` times. - 4. The list of all chain endpoints for an epoch forms the - one-time public key. - This one-time public key is hashed to create a single Merkle leaf. - 5. A Merkle tree is built over all generated leaves, and its root becomes - part of the final public key. - - Args: - activation_epoch: The starting epoch for which this key is active. - num_active_epochs: The number of consecutive epochs - the key is active for. 
- - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig - """ - # Validate the activation range against the scheme's total lifetime. - if activation_epoch + num_active_epochs > LIFETIME: - raise ValueError("Activation range exceeds the key's lifetime.") - - # Generate the random public parameter `P` and the master PRF key. - parameter = rand_parameter() - prf_key = prf_key_gen() - - # For each epoch, generate the corresponding Merkle leaf hash. - leaf_hashes: List[HashDigest] = [] - for epoch in range(activation_epoch, activation_epoch + num_active_epochs): - # For each epoch, we compute `DIMENSION` chain endpoints. - chain_ends: List[HashDigest] = [] - for chain_index in range(DIMENSION): +from .utils import PROD_RAND, TEST_RAND, Rand + + +class GeneralizedXmssScheme: + """An instance of the Generalized XMSS signature scheme for a given config.""" + + def __init__( + self, + config: XmssConfig, + prf: Prf, + hasher: TweakHasher, + merkle_tree: MerkleTree, + encoder: TargetSumEncoder, + rand: Rand, + ): + """Initializes the scheme with a specific parameter set.""" + self.config = config + self.prf = prf + self.hasher = hasher + self.merkle_tree = merkle_tree + self.encoder = encoder + self.rand = rand + + def key_gen( + self, activation_epoch: int, num_active_epochs: int + ) -> Tuple[PublicKey, SecretKey]: + """ + Generates a new cryptographic key pair. This is a **randomized** algorithm. + + This function executes the full key generation process: + 1. Generates a master secret (PRF key) and a public hash parameter. + 2. For each epoch in the active range, it uses the PRF to derive the + secret starting points for all `DIMENSION` hash chains. + 3. 
It computes the public endpoint of each chain by + hashing `BASE - 1` times. + 4. The list of all chain endpoints for an epoch forms the + one-time public key. + This one-time public key is hashed to create a single Merkle leaf. + 5. A Merkle tree is built over all generated leaves, and its root becomes + part of the final public key. + + Args: + activation_epoch: The starting epoch for which this key is active. + num_active_epochs: The number of consecutive epochs + the key is active for. + + For the formal specification of this process, please refer to: + - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 + - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 + - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + """ + # Get the config for this scheme. + config = self.config + + # Validate the activation range against the scheme's total lifetime. + if activation_epoch + num_active_epochs > config.LIFETIME: + raise ValueError("Activation range exceeds the key's lifetime.") + + # Generate the random public parameter `P` and the master PRF key. + parameter = self.rand.parameter() + prf_key = self.prf.key_gen() + + # For each epoch, generate the corresponding Merkle leaf hash. + leaf_hashes: List[HashDigest] = [] + for epoch in range( + activation_epoch, activation_epoch + num_active_epochs + ): + # For each epoch, we compute `DIMENSION` chain endpoints. + chain_ends: List[HashDigest] = [] + for chain_index in range(config.DIMENSION): + # Derive the secret start of the chain from the master key. + start_digest = self.prf.apply(prf_key, epoch, chain_index) + # Compute the public end of the chain by hashing + # BASE - 1 times. 
+ end_digest = self.hasher.hash_chain( + parameter=parameter, + epoch=epoch, + chain_index=chain_index, + start_step=0, + num_steps=config.BASE - 1, + start_digest=start_digest, + ) + chain_ends.append(end_digest) + + # The Merkle leaf is the hash of all chain endpoints + # for this epoch. + leaf_tweak = TreeTweak(level=0, index=epoch) + leaf_hash = self.hasher.apply(parameter, leaf_tweak, chain_ends) + leaf_hashes.append(leaf_hash) + + # Build the Merkle tree over the generated leaves. + tree = self.merkle_tree.build( + config.LOG_LIFETIME, activation_epoch, parameter, leaf_hashes + ) + root = self.merkle_tree.get_root(tree) + + # Assemble and return the public and secret keys. + pk = PublicKey(root=root, parameter=parameter) + sk = SecretKey( + prf_key=prf_key, + tree=tree, + parameter=parameter, + activation_epoch=activation_epoch, + num_active_epochs=num_active_epochs, + ) + return pk, sk + + def sign(self, sk: SecretKey, epoch: int, message: bytes) -> Signature: + """ + Produces a digital signature for a given message at a specific epoch. This + is a **randomized** algorithm. + + **CRITICAL**: This function must never be called twice with the same secret + key and epoch for different messages, as this would compromise security. + + The signing process involves: + 1. Repeatedly attempting to encode the message with fresh randomness + (`rho`) until a valid codeword is found. + 2. Computing the one-time signature, which consists of intermediate + values from the secret hash chains, determined by the digits of + the codeword. + 3. Retrieving the Merkle authentication path for the given epoch. + + For the formal specification of this process, please refer to: + - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 + - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 + - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + """ + # Get the config for this scheme. 
+ config = self.config + + # Check that the key is active for the requested signing epoch. + active_range = range( + sk.activation_epoch, sk.activation_epoch + sk.num_active_epochs + ) + if epoch not in active_range: + raise ValueError("Key is not active for the specified epoch.") + + # Find a valid message encoding by trying different randomness `rho`. + codeword = None + rho = None + for _ in range(config.MAX_TRIES): + current_rho = self.rand.rho() + current_codeword = self.encoder.encode( + sk.parameter, message, current_rho, epoch + ) + if current_codeword is not None: + codeword = current_codeword + rho = current_rho + break + + # If no valid encoding is found after many tries, signing fails. + if codeword is None or rho is None: + raise RuntimeError("Failed to find a valid message encoding.") + + # Compute the one-time signature hashes based on the codeword. + ots_hashes: List[HashDigest] = [] + for chain_index, steps in enumerate(codeword): # Derive the secret start of the chain from the master key. - start_digest = apply_prf(prf_key, epoch, chain_index) - # Compute the public end of the chain by hashing BASE - 1 times. - end_digest = hash_chain( - parameter=parameter, + start_digest = self.prf.apply(sk.prf_key, epoch, chain_index) + # Walk the chain for the number of steps given + # by the codeword digit. + ots_digest = self.hasher.hash_chain( + parameter=sk.parameter, epoch=epoch, chain_index=chain_index, start_step=0, - num_steps=BASE - 1, + num_steps=steps, + start_digest=start_digest, + ) + ots_hashes.append(ots_digest) + + # Get the Merkle authentication path for the current epoch. + path = self.merkle_tree.get_path(sk.tree, epoch) + + # Assemble and return the final signature. + return Signature(path=path, rho=rho, hashes=ots_hashes) + + def verify( + self, pk: PublicKey, epoch: int, message: bytes, sig: Signature + ) -> bool: + r""" + Verifies a digital signature against a public key, message, and epoch. This + is a **deterministic** algorithm. 
+
+        This implementation carries out the complete verification logic
+        described below.
+
+        ### Verification Algorithm
+
+        1. **Re-encode Message**: The verifier uses the randomness `rho` from the
+        signature to re-compute the codeword $x = (x_1, \dots, x_v)$ from the
+        message `m`.
+        This includes calculating the checksum or checking the target sum.
+
+        2. **Reconstruct One-Time Public Key**: For each intermediate hash $y_i$
+        in the signature, the verifier completes the corresponding hash chain.
+        Since $y_i$ was computed with $x_i$ steps, the verifier applies the
+        hash function an additional $w - 1 - x_i$ times to arrive at the
+        one-time public key component $pk_{ep,i}$.
+
+        3. **Compute Merkle Leaf**: The verifier hashes the reconstructed one-time
+        public key components to compute the expected Merkle leaf for `epoch`.
+
+        4. **Verify Merkle Path**: The verifier uses the `path` from the signature
+        to compute a candidate Merkle root starting from the computed leaf.
+        Verification succeeds if and only if this candidate root matches the
+        `root` in the `PublicKey`.
+
+        Args:
+            pk: The public key to verify against.
+            epoch: The epoch the signature corresponds to.
+            message: The message that was supposedly signed.
+            sig: The signature object to be verified.
+
+        Returns:
+            `True` if the signature is valid, `False` otherwise.
+
+        For the formal specification of this process, please refer to:
+        - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055
+        - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332
+        - The canonical Rust implementation: https://github.com/b-wagn/hash-sig
+        """
+        # Get the config for this scheme.
+        config = self.config
+
+        # Re-encode the message using the randomness `rho` from the signature.
+        #
+        # If the encoding is invalid, the signature is invalid.
+ codeword = self.encoder.encode(pk.parameter, message, sig.rho, epoch) + if codeword is None: + return False + + # Reconstruct the one-time public key (the list of chain endpoints). + chain_ends: List[HashDigest] = [] + for chain_index, xi in enumerate(codeword): + # The signature provides the hash value after `xi` steps. + start_digest = sig.hashes[chain_index] + # We must perform the remaining `BASE - 1 - xi` steps + # to get to the end. + num_steps_remaining = config.BASE - 1 - xi + end_digest = self.hasher.hash_chain( + parameter=pk.parameter, + epoch=epoch, + chain_index=chain_index, + start_step=xi, + num_steps=num_steps_remaining, start_digest=start_digest, ) chain_ends.append(end_digest) - # The Merkle leaf is the hash of all chain endpoints for this epoch. - leaf_tweak = TreeTweak(level=0, index=epoch) - leaf_hash = apply_tweakable_hash(parameter, leaf_tweak, chain_ends) - leaf_hashes.append(leaf_hash) - - # Build the Merkle tree over the generated leaves. - tree = build_tree(LOG_LIFETIME, activation_epoch, parameter, leaf_hashes) - root = get_root(tree) - - # Assemble and return the public and secret keys. - pk = PublicKey(root=root, parameter=parameter) - sk = SecretKey( - prf_key=prf_key, - tree=tree, - parameter=parameter, - activation_epoch=activation_epoch, - num_active_epochs=num_active_epochs, - ) - return pk, sk - - -def sign(sk: SecretKey, epoch: int, message: bytes) -> Signature: - """ - Produces a digital signature for a given message at a specific epoch. This - is a **randomized** algorithm. - - **CRITICAL**: This function must never be called twice with the same secret - key and epoch for different messages, as this would compromise security. - - The signing process involves: - 1. Repeatedly attempting to encode the message with fresh randomness - (`rho`) until a valid codeword is found. - 2. Computing the one-time signature, which consists of intermediate - values from the secret hash chains, determined by the digits of - the codeword. - 3. 
Retrieving the Merkle authentication path for the given epoch. - - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig - """ - # Check that the key is active for the requested signing epoch. - active_range = range( - sk.activation_epoch, sk.activation_epoch + sk.num_active_epochs - ) - if epoch not in active_range: - raise ValueError("Key is not active for the specified epoch.") - - # Find a valid message encoding by trying different randomness `rho`. - codeword = None - rho = None - for _ in range(MAX_TRIES): - current_rho = rand_rho() - current_codeword = encode(sk.parameter, message, current_rho, epoch) - if current_codeword is not None: - codeword = current_codeword - rho = current_rho - break - - # If no valid encoding is found after many tries, signing fails. - if codeword is None or rho is None: - raise RuntimeError("Failed to find a valid message encoding.") - - # Compute the one-time signature hashes based on the codeword. - ots_hashes: List[HashDigest] = [] - for chain_index, steps in enumerate(codeword): - # Derive the secret start of the chain from the master key. - start_digest = apply_prf(sk.prf_key, epoch, chain_index) - # Walk the chain for the number of steps given by the codeword digit. - ots_digest = hash_chain( - parameter=sk.parameter, - epoch=epoch, - chain_index=chain_index, - start_step=0, - num_steps=steps, - start_digest=start_digest, - ) - ots_hashes.append(ots_digest) - - # Get the Merkle authentication path for the current epoch. - path = get_path(sk.tree, epoch) - - # Assemble and return the final signature. 
- return Signature(path=path, rho=rho, hashes=ots_hashes) - - -def verify(pk: PublicKey, epoch: int, message: bytes, sig: Signature) -> bool: - r""" - Verifies a digital signature against a public key, message, and epoch. This - is a **deterministic** algorithm. - - This function is a placeholder. The complete verification logic is detailed - below and will be implemented in a future update. - - ### Verification Algorithm - - 1. **Re-encode Message**: The verifier uses the randomness `rho` from the - signature to re-compute the codeword $x = (x_1, \dots, x_v)$ from the - message `m`. - This includes calculating the checksum or checking the target sum. - - 2. **Reconstruct One-Time Public Key**: For each intermediate hash $y_i$ - in the signature, the verifier completes the corresponding hash chain. - Since $y_i$ was computed with $x_i$ steps, the verifier applies the - hash function an additional $w - 1 - x_i$ times to arrive at the - one-time public key component $pk_{ep,i}$. - - 3. **Compute Merkle Leaf**: The verifier hashes the reconstructed one-time - public key components to compute the expected Merkle leaf for `epoch`. - - 4. **Verify Merkle Path**: The verifier uses the `path` from the signature - to compute a candidate Merkle root starting from the computed leaf. - Verification succeeds if and only if this candidate root matches the - `root` in the `PublicKey`. - - Args: - pk: The public key to verify against. - epoch: The epoch the signature corresponds to. - message: The message that was supposedly signed. - sig: The signature object to be verified. - - Returns: - `True` if the signature is valid, `False` otherwise. 
- - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig - """ - # Re-encode the message using the randomness `rho` from the signature. - # - # If the encoding is invalid, the signature is invalid. - codeword = encode(pk.parameter, message, sig.rho, epoch) - if codeword is None: - return False - - # Reconstruct the one-time public key (the list of chain endpoints). - chain_ends: List[HashDigest] = [] - for chain_index, xi in enumerate(codeword): - # The signature provides the hash value after `xi` steps. - start_digest = sig.hashes[chain_index] - # We must perform the remaining `BASE - 1 - xi` steps - # to get to the end. - num_steps_remaining = BASE - 1 - xi - end_digest = hash_chain( + # Verify the Merkle path. This function internally hashes `chain_ends` + # to get the leaf node and then climbs the tree to recompute the root. + return self.merkle_tree.verify_path( parameter=pk.parameter, - epoch=epoch, - chain_index=chain_index, - start_step=xi, - num_steps=num_steps_remaining, - start_digest=start_digest, + root=pk.root, + position=epoch, + leaf_parts=chain_ends, + opening=sig.path, ) - chain_ends.append(end_digest) - - # Verify the Merkle path. This function internally hashes `chain_ends` - # to get the leaf node and then climbs the tree to recompute the root. 
- return verify_path( - parameter=pk.parameter, - root=pk.root, - position=epoch, - leaf_parts=chain_ends, - opening=sig.path, - ) + + +PROD_SIGNATURE_SCHEME = GeneralizedXmssScheme( + PROD_CONFIG, + PROD_PRF, + PROD_TWEAK_HASHER, + PROD_MERKLE_TREE, + PROD_TARGET_SUM_ENCODER, + PROD_RAND, +) +"""An instance configured for production-level parameters.""" + +TEST_SIGNATURE_SCHEME = GeneralizedXmssScheme( + TEST_CONFIG, + TEST_PRF, + TEST_TWEAK_HASHER, + TEST_MERKLE_TREE, + TEST_TARGET_SUM_ENCODER, + TEST_RAND, +) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/merkle_tree.py b/src/lean_spec/subspecs/xmss/merkle_tree.py index 09db20e4..ab0a26eb 100644 --- a/src/lean_spec/subspecs/xmss/merkle_tree.py +++ b/src/lean_spec/subspecs/xmss/merkle_tree.py @@ -18,6 +18,12 @@ from typing import List +from lean_spec.subspecs.xmss.constants import ( + PROD_CONFIG, + TEST_CONFIG, + XmssConfig, +) + from .structures import ( HashDigest, HashTree, @@ -25,196 +31,214 @@ HashTreeOpening, Parameter, ) -from .tweak_hash import TreeTweak -from .tweak_hash import apply as apply_tweakable_hash -from .utils import rand_domain - - -def _get_padded_layer( - nodes: List[HashDigest], start_index: int -) -> HashTreeLayer: - """ - Pads a layer to ensure its nodes can always be paired up. - - This helper function adds random padding to the start and/or end of a list - of nodes to enforce an invariant: every layer must start at an even index - and end at an odd index. This guarantees that every node has a sibling, - simplifying the construction of the next layer up. - - Args: - nodes: The list of active nodes for the current layer. - start_index: The starting index of the first node in `nodes`. - - Returns: - A new `HashTreeLayer` with padding applied. - """ - nodes_with_padding: List[HashDigest] = [] - end_index = start_index + len(nodes) - 1 - - # Prepend random padding if the layer starts at an odd index. 
- if start_index % 2 == 1: - nodes_with_padding.append(rand_domain()) - - # The actual start index of the padded layer is always the even - # number at or immediately before the original start_index. - actual_start_index = start_index - (start_index % 2) - - # Add the actual node content. - nodes_with_padding.extend(nodes) - - # Append random padding if the layer ends at an even index. - if end_index % 2 == 0: - nodes_with_padding.append(rand_domain()) - - return HashTreeLayer( - start_index=actual_start_index, nodes=nodes_with_padding - ) - - -def build_tree( - depth: int, - start_index: int, - parameter: Parameter, - leaf_hashes: List[HashDigest], -) -> HashTree: - """ - Builds a new sparse Merkle tree from a list of leaf hashes. - - The construction proceeds bottom-up, from the leaf layer to the root. - At each level, pairs of sibling nodes are hashed to create their parents - for the next level up. - - Args: - depth: The depth of the tree (e.g., 32 for a 2^32 leaf space). - start_index: The index of the first leaf in `leaf_hashes`. - parameter: The public parameter `P` for the hash function. - leaf_hashes: The list of pre-hashed leaf nodes to build the tree on. - - Returns: - The fully constructed `HashTree` object. - """ - # Start with the leaf hashes and apply the initial padding. - layers: List[HashTreeLayer] = [] - current_layer = _get_padded_layer(leaf_hashes, start_index) - layers.append(current_layer) - - # Iterate from the leaf layer (level 0) up to the root. - for level in range(depth): - parents: List[HashDigest] = [] - # Group the current layer's nodes into pairs of siblings. - for i, children in enumerate( - zip( - current_layer.nodes[0::2], - current_layer.nodes[1::2], - strict=False, - ) - ): - # Calculate the position of the parent node in the next level up. - parent_index = (current_layer.start_index // 2) + i - # Create the tweak for hashing these two children. 
- tweak = TreeTweak(level=level + 1, index=parent_index) - # Hash the left and right children to get their parent. - parent_node = apply_tweakable_hash( - parameter, tweak, list(children) - ) - parents.append(parent_node) - - # Pad the new list of parents to prepare for the next iteration. - new_start_index = current_layer.start_index // 2 - current_layer = _get_padded_layer(parents, new_start_index) +from .tweak_hash import ( + PROD_TWEAK_HASHER, + TEST_TWEAK_HASHER, + TreeTweak, + TweakHasher, +) +from .utils import PROD_RAND, TEST_RAND, Rand + + +class MerkleTree: + """An instance of the Merkle Tree handler for a given config.""" + + def __init__(self, config: XmssConfig, hasher: TweakHasher, rand: Rand): + """Initializes with a config, a hasher, and a random generator.""" + self.config = config + self.hasher = hasher + self.rand = rand + + def _get_padded_layer( + self, nodes: List[HashDigest], start_index: int + ) -> HashTreeLayer: + """ + Pads a layer to ensure its nodes can always be paired up. + + This helper function adds random padding to the start and/or end of a list + of nodes to enforce an invariant: every layer must start at an even index + and end at an odd index. This guarantees that every node has a sibling, + simplifying the construction of the next layer up. + + Args: + nodes: The list of active nodes for the current layer. + start_index: The starting index of the first node in `nodes`. + + Returns: + A new `HashTreeLayer` with padding applied. + """ + nodes_with_padding: List[HashDigest] = [] + end_index = start_index + len(nodes) - 1 + + # Prepend random padding if the layer starts at an odd index. + if start_index % 2 == 1: + nodes_with_padding.append(self.rand.domain()) + + # The actual start index of the padded layer is always the even + # number at or immediately before the original start_index. + actual_start_index = start_index - (start_index % 2) + + # Add the actual node content. 
+ nodes_with_padding.extend(nodes) + + # Append random padding if the layer ends at an even index. + if end_index % 2 == 0: + nodes_with_padding.append(self.rand.domain()) + + return HashTreeLayer( + start_index=actual_start_index, nodes=nodes_with_padding + ) + + def build( + self, + depth: int, + start_index: int, + parameter: Parameter, + leaf_hashes: List[HashDigest], + ) -> HashTree: + """ + Builds a new sparse Merkle tree from a list of leaf hashes. + + The construction proceeds bottom-up, from the leaf layer to the root. + At each level, pairs of sibling nodes are hashed to create their parents + for the next level up. + + Args: + depth: The depth of the tree (e.g., 32 for a 2^32 leaf space). + start_index: The index of the first leaf in `leaf_hashes`. + parameter: The public parameter `P` for the hash function. + leaf_hashes: The list of pre-hashed leaf nodes to build the tree on. + + Returns: + The fully constructed `HashTree` object. + """ + # Start with the leaf hashes and apply the initial padding. + layers: List[HashTreeLayer] = [] + current_layer = self._get_padded_layer(leaf_hashes, start_index) layers.append(current_layer) - return HashTree(depth=depth, layers=layers) - - -def get_root(tree: HashTree) -> HashDigest: - """Extracts the root digest from a constructed Merkle tree.""" - # The root is the single node in the final layer. - return tree.layers[-1].nodes[0] - - -def get_path(tree: HashTree, position: int) -> HashTreeOpening: - """ - Computes the Merkle authentication path for a leaf at a given position. - - The path consists of the list of sibling nodes required to reconstruct the - root, starting from the leaf's sibling and going up the tree. - - Args: - tree: The `HashTree` from which to extract the path. - position: The absolute index of the leaf for which to generate path. - - Returns: - A `HashTreeOpening` object containing the co-path. 
- """ - co_path: List[HashDigest] = [] - current_position = position - - # Iterate from the bottom layer (level 0) up to the layer below the root. - for level in range(tree.depth): - # Determine the sibling's position using an XOR operation. - sibling_position = current_position ^ 1 - # Find the sibling's index within the sparse `nodes` vector. - layer = tree.layers[level] - sibling_index_in_vec = sibling_position - layer.start_index - # Add the sibling's hash to the co-path. - co_path.append(layer.nodes[sibling_index_in_vec]) - # Move up to the parent's position for the next iteration. - current_position //= 2 - - return HashTreeOpening(siblings=co_path) - - -def verify_path( - parameter: Parameter, - root: HashDigest, - position: int, - leaf_parts: List[HashDigest], - opening: HashTreeOpening, -) -> bool: - """ - Verifies a Merkle authentication path. - - This function reconstructs a candidate root by starting with the leaf node - and repeatedly hashing it with the sibling nodes provided in the opening - path. - - The verification succeeds if the candidate root matches the true root. - - Args: - parameter: The public parameter `P` for the hash function. - root: The known, trusted Merkle root. - position: The absolute index of the leaf being verified. - leaf_parts: The list of digests that constitute the original leaf. - opening: The `HashTreeOpening` object containing the sibling path. - - Returns: - `True` if the path is valid, `False` otherwise. - """ - # The first step is to hash the constituent parts of the leaf to get - # the actual node at layer 0 of the tree. - leaf_tweak = TreeTweak(level=0, index=position) - current_node = apply_tweakable_hash(parameter, leaf_tweak, leaf_parts) - - # Iterate up the tree, hashing the current node with its sibling from - # the path at each level. - current_position = position - for level, sibling_node in enumerate(opening.siblings): - # Determine if the current node is a left or right child. 
- if current_position % 2 == 0: - # Current node is a left child; sibling is on the right. - children = [current_node, sibling_node] - else: - # Current node is a right child; sibling is on the left. - children = [sibling_node, current_node] - - # Move up to the parent's position for the next iteration. - current_position //= 2 - # Create the tweak for the parent's level and position. - parent_tweak = TreeTweak(level=level + 1, index=current_position) - # Hash the children to compute the parent node. - current_node = apply_tweakable_hash(parameter, parent_tweak, children) - - # After iterating through the entire path, the final computed node - # should be the root of the tree. - return current_node == root + # Iterate from the leaf layer (level 0) up to the root. + for level in range(depth): + parents: List[HashDigest] = [] + # Group the current layer's nodes into pairs of siblings. + for i, children in enumerate( + zip( + current_layer.nodes[0::2], + current_layer.nodes[1::2], + strict=False, + ) + ): + # Calculate the position of the parent node in the next level up. + parent_index = (current_layer.start_index // 2) + i + # Create the tweak for hashing these two children. + tweak = TreeTweak(level=level + 1, index=parent_index) + # Hash the left and right children to get their parent. + parent_node = self.hasher.apply( + parameter, tweak, list(children) + ) + parents.append(parent_node) + + # Pad the new list of parents to prepare for the next iteration. + new_start_index = current_layer.start_index // 2 + current_layer = self._get_padded_layer(parents, new_start_index) + layers.append(current_layer) + + return HashTree(depth=depth, layers=layers) + + def get_root(self, tree: HashTree) -> HashDigest: + """Extracts the root digest from a constructed Merkle tree.""" + # The root is the single node in the final layer. 
+ return tree.layers[-1].nodes[0] + + def get_path(self, tree: HashTree, position: int) -> HashTreeOpening: + """ + Computes the Merkle authentication path for a leaf at a given position. + + The path consists of the list of sibling nodes required to reconstruct the + root, starting from the leaf's sibling and going up the tree. + + Args: + tree: The `HashTree` from which to extract the path. + position: The absolute index of the leaf for which to generate path. + + Returns: + A `HashTreeOpening` object containing the co-path. + """ + co_path: List[HashDigest] = [] + current_position = position + + # Iterate from the bottom layer (level 0) up to the layer below the root. + for level in range(tree.depth): + # Determine the sibling's position using an XOR operation. + sibling_position = current_position ^ 1 + # Find the sibling's index within the sparse `nodes` vector. + layer = tree.layers[level] + sibling_index_in_vec = sibling_position - layer.start_index + # Add the sibling's hash to the co-path. + co_path.append(layer.nodes[sibling_index_in_vec]) + # Move up to the parent's position for the next iteration. + current_position //= 2 + + return HashTreeOpening(siblings=co_path) + + def verify_path( + self, + parameter: Parameter, + root: HashDigest, + position: int, + leaf_parts: List[HashDigest], + opening: HashTreeOpening, + ) -> bool: + """ + Verifies a Merkle authentication path. + + This function reconstructs a candidate root by starting with the leaf node + and repeatedly hashing it with the sibling nodes provided in the opening + path. + + The verification succeeds if the candidate root matches the true root. + + Args: + parameter: The public parameter `P` for the hash function. + root: The known, trusted Merkle root. + position: The absolute index of the leaf being verified. + leaf_parts: The list of digests that constitute the original leaf. + opening: The `HashTreeOpening` object containing the sibling path. 
+ + Returns: + `True` if the path is valid, `False` otherwise. + """ + # The first step is to hash the constituent parts of the leaf to get + # the actual node at layer 0 of the tree. + leaf_tweak = TreeTweak(level=0, index=position) + current_node = self.hasher.apply(parameter, leaf_tweak, leaf_parts) + + # Iterate up the tree, hashing the current node with its sibling from + # the path at each level. + current_position = position + for level, sibling_node in enumerate(opening.siblings): + # Determine if the current node is a left or right child. + if current_position % 2 == 0: + # Current node is a left child; sibling is on the right. + children = [current_node, sibling_node] + else: + # Current node is a right child; sibling is on the left. + children = [sibling_node, current_node] + + # Move up to the parent's position for the next iteration. + current_position //= 2 + # Create the tweak for the parent's level and position. + parent_tweak = TreeTweak(level=level + 1, index=current_position) + # Hash the children to compute the parent node. + current_node = self.hasher.apply(parameter, parent_tweak, children) + + # After iterating through the entire path, the final computed node + # should be the root of the tree. 
+ return current_node == root + + +PROD_MERKLE_TREE = MerkleTree(PROD_CONFIG, PROD_TWEAK_HASHER, PROD_RAND) +"""An instance configured for production-level parameters.""" + +TEST_MERKLE_TREE = MerkleTree(TEST_CONFIG, TEST_TWEAK_HASHER, TEST_RAND) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py index 4161ffc9..7005111b 100644 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -15,121 +15,140 @@ from typing import List +from lean_spec.subspecs.xmss.poseidon import ( + PROD_POSEIDON, + TEST_POSEIDON, + PoseidonXmss, +) + from ..koalabear import Fp, P from .constants import ( - BASE, - DIMENSION, - FINAL_LAYER, - MSG_LEN_FE, - POS_INVOCATIONS, - POS_OUTPUT_LEN_PER_INV_FE, - TWEAK_LEN_FE, + PROD_CONFIG, + TEST_CONFIG, TWEAK_PREFIX_MESSAGE, + XmssConfig, ) from .hypercube import find_layer, get_hypercube_part_size, map_to_vertex from .structures import Parameter, Randomness -from .tweak_hash import poseidon_compress - - -def encode_message(message: bytes) -> List[Fp]: - """ - Encodes a 32-byte message into a list of field elements. - - The message bytes are interpreted as a single little-endian integer, - which is then decomposed into its base-`P` representation. - """ - # Interpret the little-endian bytes as a single large integer. - acc = int.from_bytes(message, "little") - - # Decompose the integer into a list of field elements (base-P). - elements: List[Fp] = [] - for _ in range(MSG_LEN_FE): - elements.append(Fp(value=acc)) - acc //= P - return elements - - -def encode_epoch(epoch: int) -> List[Fp]: - """Encodes epoch and domain separator into a list of field elements.""" - # Combine the epoch and the message hash prefix into a single integer. - acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE.value - - # Decompose the integer into its base-P representation. 
- elements: List[Fp] = [] - for _ in range(TWEAK_LEN_FE): - elements.append(Fp(value=acc)) - acc //= P - return elements - - -def _map_into_hypercube_part(field_elements: List[Fp]) -> List[int]: - """ - Maps a list of field elements to a vertex in - the top `FINAL_LAYER` layers. - """ - # Combine field elements into one large integer (big-endian, base-P). - acc = 0 - for fe in field_elements: - acc = acc * P + fe.value - - # Reduce this integer modulo the size of the target domain. - # - # The target domain is the set of all vertices in layers 0..FINAL_LAYER. - domain_size = get_hypercube_part_size(BASE, DIMENSION, FINAL_LAYER) - acc %= domain_size - - # Find which layer the resulting index falls into, and its offset. - layer, offset = find_layer(BASE, DIMENSION, acc) - - # Map the offset within the layer to a unique vertex. - return map_to_vertex(BASE, DIMENSION, layer, offset) - - -def apply( - parameter: Parameter, - epoch: int, - randomness: Randomness, - message: bytes, -) -> List[int]: - """ - Applies the full "Top Level" message hash procedure. - - This involves multiple invocations of Poseidon2, with the combined output - mapped into a specific region of the hypercube. - - Args: - parameter: The public parameter `P`. - epoch: The current epoch. - randomness: A random value `rho`. - message: The 32-byte message to be hashed. - - Returns: - A vertex in the hypercube, represented as a list of `DIMENSION` ints. - """ - # Encode the message and epoch as field elements. - message_fe = encode_message(message) - epoch_fe = encode_epoch(epoch) - - # Iteratively call Poseidon2 to generate a long hash output. - poseidon_outputs: List[Fp] = [] - for i in range(POS_INVOCATIONS): - # Use the iteration number as a domain separator for each hash call. - iteration_separator = [Fp(value=i)] - - # The input is: rho || P || epoch || message || iteration. 
- combined_input = ( - randomness - + parameter - + epoch_fe - + message_fe - + iteration_separator - ) - # Hash the combined input using Poseidon2 compression mode. - iteration_output = poseidon_compress( - combined_input, 24, POS_OUTPUT_LEN_PER_INV_FE - ) - poseidon_outputs.extend(iteration_output) - # Map the final list of field elements into a hypercube vertex. - return _map_into_hypercube_part(poseidon_outputs) +class MessageHasher: + """An instance of the "Top Level" message hasher for a given config.""" + + def __init__(self, config: XmssConfig, poseidon_hasher: PoseidonXmss): + """Initializes the hasher with a specific parameter set.""" + self.config = config + self.poseidon = poseidon_hasher + + def encode_message(self, message: bytes) -> List[Fp]: + """ + Encodes a 32-byte message into a list of field elements. + + The message bytes are interpreted as a single little-endian integer, + which is then decomposed into its base-`P` representation. + """ + # Interpret the little-endian bytes as a single large integer. + acc = int.from_bytes(message, "little") + + # Decompose the integer into a list of field elements (base-P). + elements: List[Fp] = [] + for _ in range(self.config.MSG_LEN_FE): + elements.append(Fp(value=acc)) + acc //= P + return elements + + def encode_epoch(self, epoch: int) -> List[Fp]: + """Encodes epoch and domain separator into a list of field elements.""" + # Combine the epoch and the message hash prefix into a single integer. + acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE.value + + # Decompose the integer into its base-P representation. + elements: List[Fp] = [] + for _ in range(self.config.TWEAK_LEN_FE): + elements.append(Fp(value=acc)) + acc //= P + return elements + + def _map_into_hypercube_part(self, field_elements: List[Fp]) -> List[int]: + """ + Maps a list of field elements to a vertex in + the top `FINAL_LAYER` layers. + """ + # Get the config for this scheme. 
+ config = self.config + + # Combine field elements into one large integer (big-endian, base-P). + acc = 0 + for fe in field_elements: + acc = acc * P + fe.value + + # Reduce this integer modulo the size of the target domain. + # + # The target domain is the set of all vertices in layers 0..FINAL_LAYER. + domain_size = get_hypercube_part_size( + config.BASE, config.DIMENSION, config.FINAL_LAYER + ) + acc %= domain_size + + # Find which layer the resulting index falls into, and its offset. + layer, offset = find_layer(config.BASE, config.DIMENSION, acc) + + # Map the offset within the layer to a unique vertex. + return map_to_vertex(config.BASE, config.DIMENSION, layer, offset) + + def apply( + self, + parameter: Parameter, + epoch: int, + randomness: Randomness, + message: bytes, + ) -> List[int]: + """ + Applies the full "Top Level" message hash procedure. + + This involves multiple invocations of Poseidon2, with the combined output + mapped into a specific region of the hypercube. + + Args: + parameter: The public parameter `P`. + epoch: The current epoch. + randomness: A random value `rho`. + message: The 32-byte message to be hashed. + + Returns: + A vertex in the hypercube, represented as a list of `DIMENSION` ints. + """ + # Encode the message and epoch as field elements. + message_fe = self.encode_message(message) + epoch_fe = self.encode_epoch(epoch) + + # Iteratively call Poseidon2 to generate a long hash output. + poseidon_outputs: List[Fp] = [] + for i in range(self.config.POS_INVOCATIONS): + # Use the iteration number as a domain separator for each hash call. + iteration_separator = [Fp(value=i)] + + # The input is: rho || P || epoch || message || iteration. + combined_input = ( + randomness + + parameter + + epoch_fe + + message_fe + + iteration_separator + ) + + # Hash the combined input using Poseidon2 compression mode. 
+ iteration_output = self.poseidon.compress( + combined_input, 24, self.config.POS_OUTPUT_LEN_PER_INV_FE + ) + poseidon_outputs.extend(iteration_output) + + # Map the final list of field elements into a hypercube vertex. + return self._map_into_hypercube_part(poseidon_outputs) + + +PROD_MESSAGE_HASHER = MessageHasher(PROD_CONFIG, PROD_POSEIDON) +"""An instance configured for production-level parameters.""" + +TEST_MESSAGE_HASHER = MessageHasher(TEST_CONFIG, TEST_POSEIDON) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py new file mode 100644 index 00000000..4cb9aeb1 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -0,0 +1,136 @@ +"""Defines the Tweakable Hash function using the Poseidon2 permutation.""" + +from __future__ import annotations + +from typing import List + +from ..koalabear import Fp, P +from ..poseidon2.permutation import PARAMS_16, PARAMS_24, permute +from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig +from .structures import HashDigest + + +class PoseidonXmss: + """An instance of the Poseidon2-based tweakable hash for a given config.""" + + def __init__(self, config: XmssConfig): + """Initializes the hasher with a specific parameter set.""" + self.config = config + + def compress( + self, input_vec: List[Fp], width: int, output_len: int + ) -> HashDigest: + """ + A low-level wrapper for Poseidon2 in compression mode. + + It computes: `Truncate(Permute(padded_input) + padded_input)`. + + Args: + input_vec: The input data to be hashed. + width: The state width of the Poseidon2 permutation (16 or 24). + output_len: The number of field elements in the output digest. + + Returns: + A hash digest of `output_len` field elements. + """ + # Select the correct permutation parameters based on the state width. + params = PARAMS_16 if width == 16 else PARAMS_24 + + # Create a fixed-width buffer and copy the input, padding with zeros. 
+ padded_input = [Fp(value=0)] * width + padded_input[: len(input_vec)] = input_vec + + # Apply the Poseidon2 permutation. + permuted_state = permute(padded_input, params) + + # Apply the feed-forward step, adding the input back element-wise. + final_state = [ + p + i for p, i in zip(permuted_state, padded_input, strict=True) + ] + + # Truncate the state to the desired output length and return. + return final_state[:output_len] + + def safe_domain_separator( + self, lengths: List[int], capacity_len: int + ) -> List[Fp]: + """ + Computes a domain separator for the sponge construction. + + This function hashes a list of length parameters to create a unique + "capacity value" that configures the sponge for a specific hashing task. + + Args: + lengths: A list of integer parameters defining the hash context. + capacity_len: The desired length of the output capacity value. + + Returns: + A list of `capacity_len` field elements. + """ + # Pack the length parameters into a single large integer. + acc = 0 + for length in lengths: + acc = (acc << 32) | length + + # Decompose the integer into a list of 24 field elements for hashing. + # + # 24 is the fixed input width for this specific domain separation hash. + input_vec: List[Fp] = [] + for _ in range(24): + input_vec.append(Fp(value=acc)) + acc //= P + + # Compress the decomposed vector to produce the capacity value. + return self.compress(input_vec, 24, capacity_len) + + def sponge( + self, input_vec: List[Fp], capacity_value: List[Fp], output_len: int + ) -> HashDigest: + """ + A low-level wrapper for Poseidon2 in sponge mode. + + Args: + input_vec: The input data of arbitrary length. + capacity_value: A domain-separating value. + output_len: The number of field elements in the output digest. + + Returns: + A hash digest of `output_len` field elements. + """ + # Use the width-24 permutation for the sponge. 
+ params = PARAMS_24 + width = params.width + rate = width - len(capacity_value) + + # Pad the input vector to be an exact multiple of the rate. + num_extra = (rate - (len(input_vec) % rate)) % rate + padded_input = input_vec + [Fp(value=0)] * num_extra + + # Initialize the state with the capacity value. + state = [Fp(value=0)] * width + state[rate:] = capacity_value + + # Absorb the input in rate-sized chunks. + for i in range(0, len(padded_input), rate): + chunk = padded_input[i : i + rate] + # Add the chunk to the rate part of the state. + for j in range(rate): + state[j] += chunk[j] + # Apply the permutation. + state = permute(state, params) + + # Squeeze the output until enough elements have been generated. + output: HashDigest = [] + while len(output) < output_len: + output.extend(state[:rate]) + state = permute(state, params) + + # Truncate to the final output length and return. + return output[:output_len] + + +PROD_POSEIDON = PoseidonXmss(PROD_CONFIG) +"""An instance configured for production-level parameters.""" + +TEST_POSEIDON = PoseidonXmss(TEST_CONFIG) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index 65fd5178..b26620d2 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -11,10 +11,16 @@ import hashlib import os +from typing import List from lean_spec.subspecs.koalabear import Fp -from lean_spec.subspecs.xmss.constants import HASH_LEN_FE, PRF_KEY_LENGTH -from lean_spec.subspecs.xmss.structures import HashDigest, PRFKey +from lean_spec.subspecs.xmss.constants import ( + PRF_KEY_LENGTH, + PROD_CONFIG, + TEST_CONFIG, + XmssConfig, +) +from lean_spec.subspecs.xmss.structures import PRFKey from lean_spec.types.uint64 import Uint64 PRF_DOMAIN_SEP: bytes = bytes( @@ -55,81 +61,91 @@ """ -PRFOutput = HashDigest -""" -A type alias for the output of the PRF. -It is a list of field elements. 
-""" +class Prf: + """An instance of the SHAKE128-based PRF for a given config.""" + + def __init__(self, config: XmssConfig): + """Initializes the PRF with a specific parameter set.""" + self.config = config + + def key_gen(self) -> PRFKey: + """ + Generates a cryptographically secure random key for the PRF. + This function sources randomness from the operating system's entropy pool. -def key_gen() -> PRFKey: - """ - Generates a cryptographically secure random key for the PRF. + Returns: + A new, randomly generated PRF key of `PRF_KEY_LENGTH` bytes. + """ + return os.urandom(PRF_KEY_LENGTH) - This function sources randomness from the operating system's entropy pool. + def apply(self, key: PRFKey, epoch: int, chain_index: Uint64) -> List[Fp]: + """ + Applies the PRF to derive the secret values for a specific epoch + and chain. - Returns: - A new, randomly generated PRF key of `PRF_KEY_LENGTH` bytes. - """ - # Use os.urandom for cryptographically secure random bytes. - return os.urandom(PRF_KEY_LENGTH) + The function computes: + `SHAKE128(DOMAIN_SEP || key || epoch || chain_index)` + and interprets the output as a list of field elements. + Args: + key: The secret PRF key. + epoch: The epoch number (a 32-bit unsigned integer). + chain_index: The index of the hash chain (a 64-bit uint). -def apply(key: PRFKey, epoch: int, chain_index: Uint64) -> PRFOutput: - """ - Applies the PRF to derive the secret values for a specific epoch and chain. + Returns: + A list of `DIMENSION` field elements, which are the secret starting + points for the hash chains of the specified epoch. + """ + # Get the config for this scheme. + config = self.config - The function computes `SHAKE128(DOMAIN_SEP || key || epoch || chain_index)` - and interprets the output as a list of field elements. + # Create a new SHAKE128 hash instance. + hasher = hashlib.shake_128() - Args: - key: The secret PRF key. - epoch: The epoch number (a 32-bit unsigned integer). 
- chain_index: The index of the hash chain (a 64-bit unsigned integer). + # Absorb the domain separator to contextualize the hash. + hasher.update(PRF_DOMAIN_SEP) - Returns: - A list of `DIMENSION` field elements, which are the secret starting - points for the hash chains of the specified epoch. - """ - # Create a new SHAKE128 hash instance. - hasher = hashlib.shake_128() + # Absorb the secret key. + hasher.update(key) - # Absorb the domain separator to contextualize the hash. - hasher.update(PRF_DOMAIN_SEP) + # Absorb the epoch, represented as a 4-byte big-endian integer. + # + # This ensures that each epoch produces a unique set of secrets. + hasher.update(epoch.to_bytes(4, "big")) - # Absorb the secret key. - hasher.update(key) + # Absorb the chain index, as an 8-byte big-endian integer. + # + # This is used to derive a unique start value for each hash chain. + hasher.update(chain_index.to_bytes(8, "big")) - # Absorb the epoch, represented as a 4-byte big-endian integer. - # - # This ensures that each epoch produces a unique set of secrets. - hasher.update(epoch.to_bytes(4, "big")) + # Determine the total number of bytes to extract from the SHAKE output. + # + # For key generation, we need one field element per chain. + num_bytes_to_read = PRF_BYTES_PER_FE * config.HASH_LEN_FE + prf_output_bytes = hasher.digest(num_bytes_to_read) - # Absorb the chain index, as an 8-byte big-endian integer. - # - # This is used to derive a unique start value for each hash chain. - hasher.update(chain_index.to_bytes(8, "big")) + # Convert the byte output into a list of field elements. + output_elements: List[Fp] = [] + for i in range(config.HASH_LEN_FE): + # Extract an 8-byte chunk for the current field element. + start = i * PRF_BYTES_PER_FE + end = start + PRF_BYTES_PER_FE + chunk = prf_output_bytes[start:end] - # Determine the total number of bytes to extract from the SHAKE output. - # - # For key generation, we need one field element per chain. 
- num_bytes_to_read = PRF_BYTES_PER_FE * HASH_LEN_FE - prf_output_bytes = hasher.digest(num_bytes_to_read) + # Convert the chunk to a large integer. + integer_value = int.from_bytes(chunk, "big") - # Convert the byte output into a list of field elements. - output_elements: PRFOutput = [] - for i in range(HASH_LEN_FE): - # Extract an 8-byte chunk for the current field element. - start = i * PRF_BYTES_PER_FE - end = start + PRF_BYTES_PER_FE - chunk = prf_output_bytes[start:end] + # Reduce the integer modulo the field prime `P`. + # + # The Fp constructor handles the modulo operation automatically. + output_elements.append(Fp(value=integer_value)) - # Convert the chunk to a large integer. - integer_value = int.from_bytes(chunk, "big") + return output_elements - # Reduce the integer modulo the field prime `P`. - # - # The Fp constructor handles the modulo operation automatically. - output_elements.append(Fp(value=integer_value)) - return output_elements +PROD_PRF = Prf(PROD_CONFIG) +"""An instance configured for production-level parameters.""" + +TEST_PRF = Prf(TEST_CONFIG) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/structures.py b/src/lean_spec/subspecs/xmss/structures.py index 78ccc667..b2d532ad 100644 --- a/src/lean_spec/subspecs/xmss/structures.py +++ b/src/lean_spec/subspecs/xmss/structures.py @@ -5,7 +5,7 @@ from pydantic import BaseModel, ConfigDict, Field from ..koalabear import Fp -from .constants import HASH_LEN_FE, PARAMETER_LEN, PRF_KEY_LENGTH, RAND_LEN_FE +from .constants import PRF_KEY_LENGTH PRFKey = Annotated[ bytes, Field(min_length=PRF_KEY_LENGTH, max_length=PRF_KEY_LENGTH) @@ -17,23 +17,17 @@ """ -HashDigest = Annotated[ - List[Fp], Field(min_length=HASH_LEN_FE, max_length=HASH_LEN_FE) -] +HashDigest = List[Fp] """ A type alias representing a hash digest. 
""" -Parameter = Annotated[ - List[Fp], Field(min_length=PARAMETER_LEN, max_length=PARAMETER_LEN) -] +Parameter = List[Fp] """ A type alias representing the public parameter `P`. """ -Randomness = Annotated[ - List[Fp], Field(min_length=RAND_LEN_FE, max_length=RAND_LEN_FE) -] +Randomness = List[Fp] """ A type alias representing the randomness `rho`. """ @@ -92,10 +86,8 @@ class PublicKey(BaseModel): """The public key for the Generalized XMSS scheme.""" model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - root: List[Fp] = Field(..., max_length=HASH_LEN_FE, min_length=HASH_LEN_FE) - parameter: Parameter = Field( - ..., max_length=PARAMETER_LEN, min_length=PARAMETER_LEN - ) + root: List[Fp] + parameter: Parameter class Signature(BaseModel): @@ -103,9 +95,7 @@ class Signature(BaseModel): model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) path: HashTreeOpening - rho: Randomness = Field( - ..., max_length=RAND_LEN_FE, min_length=RAND_LEN_FE - ) + rho: Randomness hashes: List[HashDigest] diff --git a/src/lean_spec/subspecs/xmss/target_sum.py b/src/lean_spec/subspecs/xmss/target_sum.py index 66251f5b..b03c055a 100644 --- a/src/lean_spec/subspecs/xmss/target_sum.py +++ b/src/lean_spec/subspecs/xmss/target_sum.py @@ -8,42 +8,63 @@ from typing import List, Optional -from .constants import TARGET_SUM -from .message_hash import apply as apply_message_hash +from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig +from .message_hash import ( + PROD_MESSAGE_HASHER, + TEST_MESSAGE_HASHER, + MessageHasher, +) from .structures import Parameter, Randomness -def encode( - parameter: Parameter, message: bytes, rho: Randomness, epoch: int -) -> Optional[List[int]]: - """ - Encodes a message into a codeword if it meets the target sum criteria. - - This function first uses the message hash to map the input to a vertex in - the hypercube. It then checks if the sum of the vertex's coordinates - matches the scheme's `TARGET_SUM`. 
This filtering step is the core of the - Target Sum scheme. - - Args: - parameter: The public parameter `P`. - message: The message to encode. - rho: The randomness used for this encoding attempt. - epoch: The current epoch. - - Returns: - The codeword (a list of integers) if the sum is correct, - otherwise `None`. - """ - # Apply the message hash to get a potential codeword (a vertex). - codeword_candidate = apply_message_hash(parameter, epoch, rho, message) - - # Check if the candidate satisfies the target sum condition. - if sum(codeword_candidate) == TARGET_SUM: - # If the sum is correct, this is a valid codeword. - return codeword_candidate - else: - # If the sum does not match, this `rho` is invalid for this message. - # - # The caller (the `sign` function) will need to try again with new - # randomness. - return None +class TargetSumEncoder: + """An instance of the Target Sum encoder for a given config.""" + + def __init__(self, config: XmssConfig, message_hasher: MessageHasher): + """Initializes the encoder with a specific parameter set.""" + self.config = config + self.message_hasher = message_hasher + + def encode( + self, parameter: Parameter, message: bytes, rho: Randomness, epoch: int + ) -> Optional[List[int]]: + """ + Encodes a message into a codeword if it meets the target sum criteria. + + This function first uses the message hash to map the input to a vertex in + the hypercube. It then checks if the sum of the vertex's coordinates + matches the scheme's `TARGET_SUM`. This filtering step is the core of the + Target Sum scheme. + + Args: + parameter: The public parameter `P`. + message: The message to encode. + rho: The randomness used for this encoding attempt. + epoch: The current epoch. + + Returns: + The codeword (a list of integers) if the sum is correct, + otherwise `None`. + """ + # Apply the message hash to get a potential codeword (a vertex). 
+ codeword_candidate = self.message_hasher.apply( + parameter, epoch, rho, message + ) + + # Check if the candidate satisfies the target sum condition. + if sum(codeword_candidate) == self.config.TARGET_SUM: + # If the sum is correct, this is a valid codeword. + return codeword_candidate + else: + # If the sum does not match, this `rho` is invalid for this message. + # + # The caller (the `sign` function) will need to try again with new + # randomness. + return None + + +PROD_TARGET_SUM_ENCODER = TargetSumEncoder(PROD_CONFIG, PROD_MESSAGE_HASHER) +"""An instance configured for production-level parameters.""" + +TEST_TARGET_SUM_ENCODER = TargetSumEncoder(TEST_CONFIG, TEST_MESSAGE_HASHER) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py index 24508923..072c5324 100644 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -16,18 +16,20 @@ from pydantic import Field +from lean_spec.subspecs.xmss.poseidon import ( + PROD_POSEIDON, + TEST_POSEIDON, + PoseidonXmss, +) from lean_spec.types.base import StrictBaseModel from ..koalabear import Fp, P -from ..poseidon2.permutation import PARAMS_16, PARAMS_24, permute from .constants import ( - CAPACITY, - DIMENSION, - HASH_LEN_FE, - PARAMETER_LEN, - TWEAK_LEN_FE, + PROD_CONFIG, + TEST_CONFIG, TWEAK_PREFIX_CHAIN, TWEAK_PREFIX_TREE, + XmssConfig, ) from .structures import HashDigest, Parameter @@ -47,251 +49,167 @@ class ChainTweak(StrictBaseModel): step: int = Field(ge=0, description="The step number within the chain.") -Tweak = Union[TreeTweak, ChainTweak] -"""A type alias representing any valid tweak structure.""" - - -def encode_tweak(tweak: Tweak, length: int) -> List[Fp]: - """ - Encodes a tweak structure into a list of field elements for hashing. 
- - This function packs the tweak's integer components into a single large - integer, then performs a base-`P` decomposition to get field elements. - This ensures a unique and deterministic mapping from any tweak to a format - consumable by the hash function. - - The packing scheme is designed to be injective: - - `TreeTweak`: `(level << 40) | (index << 8) | TWEAK_PREFIX_TREE` - - `ChainTweak`: `(epoch << 24) | (chain_index << 16) - | (step << 8) | TWEAK_PREFIX_CHAIN` - - Args: - tweak: The `TreeTweak` or `ChainTweak` object. - length: The desired number of field elements in the output list. - - Returns: - A list of `length` field elements representing the encoded tweak. - """ - # Pack the tweak's integer fields into a single large integer. - # - # A unique prefix is included for domain separation. - if isinstance(tweak, TreeTweak): - acc = ( - (tweak.level << 40) | (tweak.index << 8) | TWEAK_PREFIX_TREE.value - ) - else: - acc = ( - (tweak.epoch << 24) - | (tweak.chain_index << 16) - | (tweak.step << 8) - | TWEAK_PREFIX_CHAIN.value - ) - - # Decompose the large integer `acc` into a list of field elements. - # This is a standard base-P decomposition. - # - # The number of elements is determined by the `length` parameter. - elements: List[Fp] = [] - for _ in range(length): - elements.append(Fp(value=acc)) - acc //= P - return elements - - -def poseidon_compress( - input_vec: List[Fp], width: int, output_len: int -) -> HashDigest: - """ - A low-level wrapper for Poseidon2 in compression mode. - - It computes: `Truncate(Permute(padded_input) + padded_input)`. - - Args: - input_vec: The input data to be hashed. - width: The state width of the Poseidon2 permutation (16 or 24). - output_len: The number of field elements in the output digest. - - Returns: - A hash digest of `output_len` field elements. - """ - # Select the correct permutation parameters based on the state width. 
- params = PARAMS_16 if width == 16 else PARAMS_24 - - # Create a fixed-width buffer and copy the input, padding with zeros. - padded_input = [Fp(value=0)] * width - padded_input[: len(input_vec)] = input_vec - - # Apply the Poseidon2 permutation. - permuted_state = permute(padded_input, params) - - # Apply the feed-forward step, adding the input back element-wise. - final_state = [ - p + i for p, i in zip(permuted_state, padded_input, strict=True) - ] - - # Truncate the state to the desired output length and return. - return final_state[:output_len] - - -def _poseidon_safe_domain_separator( - lengths: List[int], capacity_len: int -) -> List[Fp]: - """ - Computes a domain separator for the sponge construction. - - This function hashes a list of length parameters to create a unique - "capacity value" that configures the sponge for a specific hashing task. - - Args: - lengths: A list of integer parameters defining the hash context. - capacity_len: The desired length of the output capacity value. - - Returns: - A list of `capacity_len` field elements. - """ - # Pack the length parameters into a single large integer. - acc = 0 - for length in lengths: - acc = (acc << 32) | length - - # Decompose the integer into a list of 24 field elements for hashing. - # - # 24 is the fixed input width for this specific domain separation hash. - input_vec: List[Fp] = [] - for _ in range(24): - input_vec.append(Fp(value=acc)) - acc //= P - - # Compress the decomposed vector to produce the capacity value. - return poseidon_compress(input_vec, 24, capacity_len) - - -def _poseidon_sponge( - input_vec: List[Fp], capacity_value: List[Fp], output_len: int -) -> HashDigest: - """ - A low-level wrapper for Poseidon2 in sponge mode. - - Args: - input_vec: The input data of arbitrary length. - capacity_value: A domain-separating value. - output_len: The number of field elements in the output digest. - - Returns: - A hash digest of `output_len` field elements. 
- """ - # Use the width-24 permutation for the sponge. - params = PARAMS_24 - width = params.width - rate = width - len(capacity_value) - - # Pad the input vector to be an exact multiple of the rate. - num_extra = (rate - (len(input_vec) % rate)) % rate - padded_input = input_vec + [Fp(value=0)] * num_extra - - # Initialize the state with the capacity value. - state = [Fp(value=0)] * width - state[rate:] = capacity_value - - # Absorb the input in rate-sized chunks. - for i in range(0, len(padded_input), rate): - chunk = padded_input[i : i + rate] - # Add the chunk to the rate part of the state. - for j in range(rate): - state[j] += chunk[j] - # Apply the permutation. - state = permute(state, params) - - # Squeeze the output until enough elements have been generated. - output: HashDigest = [] - while len(output) < output_len: - output.extend(state[:rate]) - state = permute(state, params) - - # Truncate to the final output length and return. - return output[:output_len] - - -def apply( - parameter: Parameter, tweak: Tweak, message_parts: List[HashDigest] -) -> HashDigest: - """ - Applies the tweakable Poseidon2 hash function to a message. - - This function serves as the main entry point for all hashing operations. - It automatically selects the correct Poseidon2 mode (compression or sponge) - based on the number of message parts provided. - - Args: - parameter: The public parameter `P` for the hash function. - tweak: A `TreeTweak` or `ChainTweak` for domain separation. - message_parts: A list of one or more hash digests to be hashed. - - Returns: - A new hash digest of `HASH_LEN_FE` field elements. - """ - # Encode the tweak structure into field elements. - encoded_tweak = encode_tweak(tweak, TWEAK_LEN_FE) - - if len(message_parts) == 1: - # Case 1: Hashing a single digest (used in hash chains). - # - # We use the efficient width-16 compression mode. 
- input_vec = parameter + encoded_tweak + message_parts[0] - return poseidon_compress(input_vec, 16, HASH_LEN_FE) +class TweakHasher: + """An instance of the Tweakable Hasher for a given config.""" + + def __init__(self, config: XmssConfig, poseidon_hasher: PoseidonXmss): + """Initializes the hasher with a specific parameter set.""" + self.config = config + self.poseidon = poseidon_hasher + + Tweak = Union[TreeTweak, ChainTweak] + """A type alias representing any valid tweak structure.""" + + def _encode_tweak(self, tweak: Tweak, length: int) -> List[Fp]: + """ + Encodes a tweak structure into a list of field elements for hashing. + + This function packs the tweak's integer components into a single large + integer, then performs a base-`P` decomposition to get field elements. + This ensures a unique and deterministic mapping from any tweak to a format + consumable by the hash function. + + The packing scheme is designed to be injective: + - `TreeTweak`: `(level << 40) | (index << 8) | TWEAK_PREFIX_TREE` + - `ChainTweak`: `(epoch << 24) | (chain_index << 16) + | (step << 8) | TWEAK_PREFIX_CHAIN` + + Args: + tweak: The `TreeTweak` or `ChainTweak` object. + length: The desired number of field elements in the output list. - elif len(message_parts) == 2: - # Case 2: Hashing two digests (used for Merkle tree nodes). + Returns: + A list of `length` field elements representing the encoded tweak. + """ + # Pack the tweak's integer fields into a single large integer. # - # We use the width-24 compression mode. - input_vec = ( - parameter + encoded_tweak + message_parts[0] + message_parts[1] - ) - return poseidon_compress(input_vec, 24, HASH_LEN_FE) - - else: - # Case 3: Hashing many digests (used for the Merkle tree leaf). + # A unique prefix is included for domain separation. 
+ if isinstance(tweak, TreeTweak): + acc = ( + (tweak.level << 40) + | (tweak.index << 8) + | TWEAK_PREFIX_TREE.value + ) + else: + acc = ( + (tweak.epoch << 24) + | (tweak.chain_index << 16) + | (tweak.step << 8) + | TWEAK_PREFIX_CHAIN.value + ) + + # Decompose the large integer `acc` into a list of field elements. + # This is a standard base-P decomposition. # - # We use the robust sponge mode. - # First, flatten the list of message parts into a single vector. - flattened_message = list(chain.from_iterable(message_parts)) - input_vec = parameter + encoded_tweak + flattened_message - - # Create a domain separator based on the dimensions of the input. - lengths = [PARAMETER_LEN, TWEAK_LEN_FE, DIMENSION, HASH_LEN_FE] - capacity_value = _poseidon_safe_domain_separator(lengths, CAPACITY) - - return _poseidon_sponge(input_vec, capacity_value, HASH_LEN_FE) - - -def hash_chain( - parameter: Parameter, - epoch: int, - chain_index: int, - start_step: int, - num_steps: int, - start_digest: HashDigest, -) -> HashDigest: - """ - Performs repeated hashing to traverse a WOTS+ hash chain. - - Args: - parameter: The public parameter `P`. - epoch: The signature epoch. - chain_index: The index of the hash chain. - start_step: The starting step number in the chain. - num_steps: The number of hashing steps to perform. - start_digest: The digest to begin hashing from. - - Returns: - The final hash digest after `num_steps` applications. - """ - current_digest = start_digest - for i in range(num_steps): - # Create a unique tweak for the current position in the chain. - tweak = ChainTweak( - epoch=epoch, chain_index=chain_index, step=start_step + i + 1 - ) - # Apply the hash function. - current_digest = apply(parameter, tweak, [current_digest]) - return current_digest + # The number of elements is determined by the `length` parameter. 
+ elements: List[Fp] = [] + for _ in range(length): + elements.append(Fp(value=acc)) + acc //= P + return elements + + def apply( + self, + parameter: Parameter, + tweak: Tweak, + message_parts: List[HashDigest], + ) -> HashDigest: + """ + Applies the tweakable Poseidon2 hash function to a message. + + This function serves as the main entry point for all hashing operations. + It automatically selects the correct Poseidon2 mode (compression or sponge) + based on the number of message parts provided. + + Args: + parameter: The public parameter `P` for the hash function. + tweak: A `TreeTweak` or `ChainTweak` for domain separation. + message_parts: A list of one or more hash digests to be hashed. + + Returns: + A new hash digest of `HASH_LEN_FE` field elements. + """ + # Get the config for this scheme. + config = self.config + + # Encode the tweak structure into field elements. + encoded_tweak = self._encode_tweak(tweak, config.TWEAK_LEN_FE) + + if len(message_parts) == 1: + # Case 1: Hashing a single digest (used in hash chains). + # + # We use the efficient width-16 compression mode. + input_vec = parameter + encoded_tweak + message_parts[0] + return self.poseidon.compress(input_vec, 16, config.HASH_LEN_FE) + + elif len(message_parts) == 2: + # Case 2: Hashing two digests (used for Merkle tree nodes). + # + # We use the width-24 compression mode. + input_vec = ( + parameter + encoded_tweak + message_parts[0] + message_parts[1] + ) + return self.poseidon.compress(input_vec, 24, config.HASH_LEN_FE) + + else: + # Case 3: Hashing many digests (used for the Merkle tree leaf). + # + # We use the robust sponge mode. + # First, flatten the list of message parts into a single vector. + flattened_message = list(chain.from_iterable(message_parts)) + input_vec = parameter + encoded_tweak + flattened_message + + # Create a domain separator based on the dimensions of the input. 
+ lengths = [ + config.PARAMETER_LEN, + config.TWEAK_LEN_FE, + config.DIMENSION, + config.HASH_LEN_FE, + ] + capacity_value = self.poseidon.safe_domain_separator( + lengths, config.CAPACITY + ) + + return self.poseidon.sponge( + input_vec, capacity_value, config.HASH_LEN_FE + ) + + def hash_chain( + self, + parameter: Parameter, + epoch: int, + chain_index: int, + start_step: int, + num_steps: int, + start_digest: HashDigest, + ) -> HashDigest: + """ + Performs repeated hashing to traverse a WOTS+ hash chain. + + Args: + parameter: The public parameter `P`. + epoch: The signature epoch. + chain_index: The index of the hash chain. + start_step: The starting step number in the chain. + num_steps: The number of hashing steps to perform. + start_digest: The digest to begin hashing from. + + Returns: + The final hash digest after `num_steps` applications. + """ + current_digest = start_digest + for i in range(num_steps): + # Create a unique tweak for the current position in the chain. + tweak = ChainTweak( + epoch=epoch, chain_index=chain_index, step=start_step + i + 1 + ) + # Apply the hash function. 
+ current_digest = self.apply(parameter, tweak, [current_digest]) + return current_digest + + +PROD_TWEAK_HASHER = TweakHasher(PROD_CONFIG, PROD_POSEIDON) +"""An instance configured for production-level parameters.""" + +TEST_TWEAK_HASHER = TweakHasher(TEST_CONFIG, TEST_POSEIDON) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py index 9c04efc1..01aacc7b 100644 --- a/src/lean_spec/subspecs/xmss/utils.py +++ b/src/lean_spec/subspecs/xmss/utils.py @@ -4,26 +4,38 @@ from typing import List from ..koalabear import Fp, P -from .constants import HASH_LEN_FE, PARAMETER_LEN, RAND_LEN_FE +from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig from .structures import HashDigest, Parameter, Randomness -def rand_field_elements(length: int) -> List[Fp]: - """Generates a random list of field elements.""" - # For each element, generate a secure random integer in the range [0, P-1]. - return [Fp(value=secrets.randbelow(P)) for _ in range(length)] +class Rand: + """An instance of the random data generator for a given config.""" + def __init__(self, config: XmssConfig): + """Initializes the generator with a specific parameter set.""" + self.config = config -def rand_parameter() -> Parameter: - """Generates a random public parameter.""" - return rand_field_elements(PARAMETER_LEN) + def field_elements(self, length: int) -> List[Fp]: + """Generates a random list of field elements.""" + # For each element, generate a secure random integer + # in the range [0, P-1]. 
+ return [Fp(value=secrets.randbelow(P)) for _ in range(length)] + def parameter(self) -> Parameter: + """Generates a random public parameter.""" + return self.field_elements(self.config.PARAMETER_LEN) -def rand_domain() -> HashDigest: - """Generates a random hash digest.""" - return rand_field_elements(HASH_LEN_FE) + def domain(self) -> HashDigest: + """Generates a random hash digest.""" + return self.field_elements(self.config.HASH_LEN_FE) + def rho(self) -> Randomness: + """Generates randomness `rho` for message encoding.""" + return self.field_elements(self.config.RAND_LEN_FE) -def rand_rho() -> Randomness: - """Generates randomness `rho` for message encoding.""" - return rand_field_elements(RAND_LEN_FE) + +PROD_RAND = Rand(PROD_CONFIG) +"""An instance configured for production-level parameters.""" + +TEST_RAND = Rand(TEST_CONFIG) +"""A lightweight instance for test environments.""" diff --git a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py index 1d332f14..0992922d 100644 --- a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py +++ b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py @@ -9,26 +9,21 @@ import pytest from lean_spec.subspecs.xmss.merkle_tree import ( - build_tree, - get_path, - get_root, - verify_path, + PROD_MERKLE_TREE, + MerkleTree, ) -from lean_spec.subspecs.xmss.structures import HashDigest, Parameter +from lean_spec.subspecs.xmss.structures import HashDigest from lean_spec.subspecs.xmss.tweak_hash import ( TreeTweak, ) -from lean_spec.subspecs.xmss.tweak_hash import ( - apply as apply_tweakable_hash, -) -from lean_spec.subspecs.xmss.utils import ( - rand_domain, - rand_parameter, -) def _run_commit_open_verify_roundtrip( - num_leaves: int, depth: int, start_index: int, leaf_parts_len: int + merkle_tree: MerkleTree, + num_leaves: int, + depth: int, + start_index: int, + leaf_parts_len: int, ) -> None: """ A helper function to perform a full Merkle tree roundtrip test. 
@@ -46,15 +41,15 @@ def _run_commit_open_verify_roundtrip( leaf_parts_len: The number of digests that constitute a single leaf. """ # SETUP: Generate a random parameter and the raw leaf data. - parameter: Parameter = rand_parameter() + parameter = merkle_tree.rand.parameter() leaves: list[list[HashDigest]] = [ - [rand_domain() for _ in range(leaf_parts_len)] + [merkle_tree.rand.domain() for _ in range(leaf_parts_len)] for _ in range(num_leaves) ] # HASH LEAVES: Compute the layer 0 nodes by hashing the leaf parts. leaf_hashes: list[HashDigest] = [ - apply_tweakable_hash( + merkle_tree.hasher.apply( parameter, TreeTweak(level=0, index=start_index + i), leaf_parts, @@ -63,14 +58,16 @@ def _run_commit_open_verify_roundtrip( ] # COMMIT: Build the Merkle tree from the leaf hashes. - tree = build_tree(depth, start_index, parameter, leaf_hashes) - root = get_root(tree) + tree = merkle_tree.build(depth, start_index, parameter, leaf_hashes) + root = merkle_tree.get_root(tree) # OPEN & VERIFY: For each leaf, generate and verify its path. 
for i, leaf_parts in enumerate(leaves): position = start_index + i - opening = get_path(tree, position) - is_valid = verify_path(parameter, root, position, leaf_parts, opening) + opening = merkle_tree.get_path(tree, position) + is_valid = merkle_tree.verify_path( + parameter, root, position, leaf_parts, opening + ) assert is_valid, f"Verification failed for leaf at position {position}" @@ -102,5 +99,5 @@ def test_commit_open_verify_roundtrip( assert start_index + num_leaves <= (1 << depth) _run_commit_open_verify_roundtrip( - num_leaves, depth, start_index, leaf_parts_len + PROD_MERKLE_TREE, num_leaves, depth, start_index, leaf_parts_len ) diff --git a/tests/lean_spec/subspecs/xmss/test_message_hash.py b/tests/lean_spec/subspecs/xmss/test_message_hash.py index 6029fd3b..7a7c1c39 100644 --- a/tests/lean_spec/subspecs/xmss/test_message_hash.py +++ b/tests/lean_spec/subspecs/xmss/test_message_hash.py @@ -6,60 +6,59 @@ from lean_spec.subspecs.koalabear import Fp, P from lean_spec.subspecs.xmss.constants import ( - BASE, - DIMENSION, - FINAL_LAYER, - MESSAGE_LENGTH, - MSG_LEN_FE, - TWEAK_LEN_FE, + TEST_CONFIG, TWEAK_PREFIX_MESSAGE, ) from lean_spec.subspecs.xmss.message_hash import ( - apply, - encode_epoch, - encode_message, + TEST_MESSAGE_HASHER, ) -from lean_spec.subspecs.xmss.utils import rand_parameter, rand_rho +from lean_spec.subspecs.xmss.utils import TEST_RAND def test_encode_message() -> None: """Tests `encode_message` with various message patterns.""" + config = TEST_CONFIG + hasher = TEST_MESSAGE_HASHER + # All-zero message - msg_zeros = b"\x00" * MESSAGE_LENGTH - encoded_zeros = encode_message(msg_zeros) - assert len(encoded_zeros) == MSG_LEN_FE + msg_zeros = b"\x00" * config.MESSAGE_LENGTH + encoded_zeros = hasher.encode_message(msg_zeros) + assert len(encoded_zeros) == config.MSG_LEN_FE assert all(fe.value == 0 for fe in encoded_zeros) # All-max message (0xff) - msg_max = b"\xff" * MESSAGE_LENGTH + msg_max = b"\xff" * config.MESSAGE_LENGTH acc = 
int.from_bytes(msg_max, "little") expected_max: List[Fp] = [] - for _ in range(MSG_LEN_FE): + for _ in range(config.MSG_LEN_FE): expected_max.append(Fp(value=acc)) acc //= P - assert encode_message(msg_max) == expected_max + assert hasher.encode_message(msg_max) == expected_max def test_encode_epoch() -> None: """ Tests `encode_epoch` for correctness and injectivity. """ + hasher = TEST_MESSAGE_HASHER + config = TEST_CONFIG + # Test specific values from the Rust reference tests. test_epochs = [0, 42, 2**32 - 1] for epoch in test_epochs: acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE.value expected: List[Fp] = [] - for _ in range(TWEAK_LEN_FE): + for _ in range(config.TWEAK_LEN_FE): expected.append(Fp(value=acc)) acc //= P - assert encode_epoch(epoch) == expected + assert hasher.encode_epoch(epoch) == expected # Test for injectivity. It is highly unlikely for a collision to occur # with a few random samples if the encoding is injective. num_trials = 1000 seen_encodings: set[tuple[Fp, ...]] = set() for i in range(num_trials): - encoding = tuple(encode_epoch(i)) + encoding = tuple(hasher.encode_epoch(i)) assert encoding not in seen_encodings seen_encodings.add(encoding) @@ -69,21 +68,25 @@ def test_apply_output_is_in_correct_hypercube_part() -> None: Tests that the output of `apply` is a valid vertex that lies within the top `FINAL_LAYER` layers of the hypercube. """ + config = TEST_CONFIG + hasher = TEST_MESSAGE_HASHER + rand = TEST_RAND + # Setup with random inputs. - parameter = rand_parameter() + parameter = rand.parameter() epoch = 313 - randomness = rand_rho() - message = b"\xaa" * MESSAGE_LENGTH + randomness = rand.rho() + message = b"\xaa" * config.MESSAGE_LENGTH # Call the message hash function. - vertex = apply(parameter, epoch, randomness, message) + vertex = hasher.apply(parameter, epoch, randomness, message) # Verify the properties of the output vertex. # # The length of the vertex must be equal to the hypercube's dimension. 
- assert len(vertex) == DIMENSION + assert len(vertex) == config.DIMENSION # Each coordinate must be smaller than the base `w`. - assert all(0 <= coord < BASE for coord in vertex) + assert all(0 <= coord < config.BASE for coord in vertex) # Check that the vertex lies in the correct set of layers. # @@ -93,6 +96,8 @@ def test_apply_output_is_in_correct_hypercube_part() -> None: # # This is equivalent to `sum(coords) >= v*(w-1) - FINAL_LAYER`. coord_sum = sum(vertex) - min_required_sum = (BASE - 1) * DIMENSION - FINAL_LAYER + min_required_sum = ( + config.BASE - 1 + ) * config.DIMENSION - config.FINAL_LAYER assert coord_sum >= min_required_sum, "Vertex is not in the top layers" diff --git a/tests/lean_spec/subspecs/xmss/test_prf.py b/tests/lean_spec/subspecs/xmss/test_prf.py index 236d8cdf..924a715d 100644 --- a/tests/lean_spec/subspecs/xmss/test_prf.py +++ b/tests/lean_spec/subspecs/xmss/test_prf.py @@ -2,8 +2,11 @@ Tests for the SHAKE128-based pseudorandom function (PRF). """ -from lean_spec.subspecs.xmss.constants import HASH_LEN_FE, PRF_KEY_LENGTH -from lean_spec.subspecs.xmss.prf import apply, key_gen +from lean_spec.subspecs.xmss.constants import ( + PRF_KEY_LENGTH, + TEST_CONFIG, +) +from lean_spec.subspecs.xmss.prf import TEST_PRF def test_key_gen_is_random() -> None: @@ -13,15 +16,17 @@ def test_key_gen_is_random() -> None: This test mirrors the logic from the reference Rust implementation. """ + prf = TEST_PRF + # Check that the key has the correct length. - key = key_gen() + key = prf.key_gen() assert len(key) == PRF_KEY_LENGTH # Generate multiple keys and ensure they are not all identical. # # This is a basic check to ensure we are getting fresh randomness. num_trials = 10 - keys = {key_gen() for _ in range(num_trials)} + keys = {prf.key_gen() for _ in range(num_trials)} assert len(keys) == num_trials # Check that the keys are not filled with a single repeated byte. 
@@ -30,7 +35,7 @@ def test_key_gen_is_random() -> None: # such a key, so this is a good health check. all_same_count = 0 for _ in range(num_trials): - key = key_gen() + key = prf.key_gen() # A set will have size 1 if all elements are the same. if len(set(key)) == 1: all_same_count += 1 @@ -44,24 +49,27 @@ def test_apply_is_sensitive_to_inputs() -> None: This confirms that all parts of the input (key, epoch, chain_index) are being correctly absorbed by the hash function. """ + prf = TEST_PRF + config = TEST_CONFIG + # Generate a baseline output with a set of initial inputs. key1 = b"\x11" * PRF_KEY_LENGTH epoch1 = 10 chain_index1 = 20 - baseline_output = apply(key1, epoch1, chain_index1) - assert len(baseline_output) == HASH_LEN_FE + baseline_output = prf.apply(key1, epoch1, chain_index1) + assert len(baseline_output) == config.HASH_LEN_FE # Test sensitivity to the key. key2 = b"\x22" * PRF_KEY_LENGTH - output_key_changed = apply(key2, epoch1, chain_index1) + output_key_changed = prf.apply(key2, epoch1, chain_index1) assert baseline_output != output_key_changed # Test sensitivity to the epoch. epoch2 = 11 - output_epoch_changed = apply(key1, epoch2, chain_index1) + output_epoch_changed = prf.apply(key1, epoch2, chain_index1) assert baseline_output != output_epoch_changed # Test sensitivity to the chain_index. 
chain_index2 = 21 - output_index_changed = apply(key1, epoch1, chain_index2) + output_index_changed = prf.apply(key1, epoch1, chain_index2) assert baseline_output != output_index_changed From 9d279760eaf813ba7413bd81bcaf091fc637ced4 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 18:02:41 +0200 Subject: [PATCH 08/13] add end to end tests --- .../lean_spec/subspecs/xmss/test_interface.py | 85 +++++++++++++++++++ 1 file changed, 85 insertions(+) create mode 100644 tests/lean_spec/subspecs/xmss/test_interface.py diff --git a/tests/lean_spec/subspecs/xmss/test_interface.py b/tests/lean_spec/subspecs/xmss/test_interface.py new file mode 100644 index 00000000..80987eb8 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_interface.py @@ -0,0 +1,85 @@ +""" +End-to-end tests for the Generalized XMSS signature scheme. +""" + +import pytest + +from lean_spec.subspecs.xmss.interface import ( + TEST_SIGNATURE_SCHEME, + GeneralizedXmssScheme, +) + +# ================================================================= +# Test Helper Function +# ================================================================= + + +def _test_correctness_roundtrip( + scheme: GeneralizedXmssScheme, + activation_epoch: int, + num_active_epochs: int, +) -> None: + """ + A helper to perform a full key_gen -> sign -> verify roundtrip. + + It generates a key pair, signs a message at a specific epoch, and + verifies the signature. It also checks that verification fails for + an incorrect message or epoch. + """ + # KEY GENERATION + # + # Generate a new key pair for the specified active range. + pk, sk = scheme.key_gen(activation_epoch, num_active_epochs) + + # SIGN & VERIFY + # + # Pick a sample epoch within the active range to test signing. + test_epoch = activation_epoch + num_active_epochs // 2 + message = b"\x42" * scheme.config.MESSAGE_LENGTH + + # Sign the message at the chosen epoch. + # + # This might take a moment as it may try multiple `rho` values. 
+    signature = scheme.sign(sk, test_epoch, message)
+
+    # Verification of the valid signature must succeed.
+    is_valid = scheme.verify(pk, test_epoch, message, signature)
+    assert is_valid, "Verification of a valid signature failed"
+
+    # TEST INVALID CASES
+    #
+    # Verification must fail if the message is tampered with.
+    tampered_message = b"\x43" * scheme.config.MESSAGE_LENGTH
+    is_invalid_msg = scheme.verify(pk, test_epoch, tampered_message, signature)
+    assert not is_invalid_msg, "Verification succeeded for a tampered message"
+
+    # Verification must fail if the epoch is incorrect.
+    if num_active_epochs > 1:
+        wrong_epoch = test_epoch + 1
+        is_invalid_epoch = scheme.verify(pk, wrong_epoch, message, signature)
+        assert not is_invalid_epoch, (
+            "Verification succeeded for an incorrect epoch"
+        )
+
+
+@pytest.mark.parametrize(
+    "activation_epoch, num_active_epochs, description",
+    [
+        (10, 4, "Standard case with a short, active lifetime"),
+        (0, 8, "Lifetime starting at epoch 0"),
+        (20, 1, "Lifetime with only a single active epoch"),
+        (7, 5, "Lifetime starting at an odd-numbered epoch"),
+    ],
+)
+def test_signature_scheme_correctness(
+    activation_epoch: int, num_active_epochs: int, description: str
+) -> None:
+    """
+    Runs an end-to-end test of the signature scheme using the lightweight
+    TEST_SIGNATURE_SCHEME configuration across various lifetime scenarios.
+ """ + _test_correctness_roundtrip( + scheme=TEST_SIGNATURE_SCHEME, + activation_epoch=activation_epoch, + num_active_epochs=num_active_epochs, + ) From 472d2c29b4ea7cba7df900dbb4996c4cf6988b4f Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 18:11:24 +0200 Subject: [PATCH 09/13] fix linter --- src/lean_spec/subspecs/xmss/constants.py | 14 +++-- src/lean_spec/subspecs/xmss/interface.py | 62 +++++++++++---------- src/lean_spec/subspecs/xmss/merkle_tree.py | 34 ++++++----- src/lean_spec/subspecs/xmss/message_hash.py | 13 +++-- src/lean_spec/subspecs/xmss/poseidon.py | 3 +- src/lean_spec/subspecs/xmss/prf.py | 3 +- src/lean_spec/subspecs/xmss/target_sum.py | 11 ++-- src/lean_spec/subspecs/xmss/tweak_hash.py | 10 ++-- 8 files changed, 85 insertions(+), 65 deletions(-) diff --git a/src/lean_spec/subspecs/xmss/constants.py b/src/lean_spec/subspecs/xmss/constants.py index b847c4f6..dec65017 100644 --- a/src/lean_spec/subspecs/xmss/constants.py +++ b/src/lean_spec/subspecs/xmss/constants.py @@ -1,8 +1,12 @@ """ -Defines the cryptographic constants and configuration presets for the XMSS spec. +Defines the cryptographic constants and configuration presets for the +XMSS spec. This specification corresponds to the "hashing-optimized" Top Level Target Sum -instantiation from the canonical Rust implementation. +instantiation from the canonical Rust implementation +(production instantiation). + +We also provide a test instantiation for testing purposes. .. note:: This specification uses the **KoalaBear** prime field, which is consistent @@ -35,7 +39,7 @@ class XmssConfig(BaseModel): """The base-2 logarithm of the scheme's maximum lifetime.""" @property - def LIFETIME(self) -> int: + def LIFETIME(self) -> int: # noqa: N802 """ The maximum number of epochs supported by this configuration. 
@@ -51,7 +55,7 @@ def LIFETIME(self) -> int: """The alphabet size for the digits of the encoded message.""" FINAL_LAYER: int - """The number of top layers of the hypercube to map the hash output into.""" + """Number of top layers of the hypercube to map the hash output into.""" TARGET_SUM: int """The required sum of all codeword chunks for a signature to be valid.""" @@ -95,7 +99,7 @@ def LIFETIME(self) -> int: """Number of invocations for the message hash.""" @property - def POS_OUTPUT_LEN_FE(self) -> int: + def POS_OUTPUT_LEN_FE(self) -> int: # noqa: N802 """Total output length for the message hash.""" return self.POS_OUTPUT_LEN_PER_INV_FE * self.POS_INVOCATIONS diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 2ad456ea..190ae126 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -38,7 +38,7 @@ class GeneralizedXmssScheme: - """An instance of the Generalized XMSS signature scheme for a given config.""" + """Instance of the Generalized XMSS signature scheme for a given config.""" def __init__( self, @@ -61,7 +61,9 @@ def key_gen( self, activation_epoch: int, num_active_epochs: int ) -> Tuple[PublicKey, SecretKey]: """ - Generates a new cryptographic key pair. This is a **randomized** algorithm. + Generates a new cryptographic key pair. + + This is a **randomized** algorithm. This function executes the full key generation process: 1. Generates a master secret (PRF key) and a public hash parameter. @@ -72,8 +74,8 @@ def key_gen( 4. The list of all chain endpoints for an epoch forms the one-time public key. This one-time public key is hashed to create a single Merkle leaf. - 5. A Merkle tree is built over all generated leaves, and its root becomes - part of the final public key. + 5. A Merkle tree is built over all generated leaves, and its root + becomes part of the final public key. Args: activation_epoch: The starting epoch for which this key is active. 
@@ -143,18 +145,20 @@ def key_gen( def sign(self, sk: SecretKey, epoch: int, message: bytes) -> Signature: """ - Produces a digital signature for a given message at a specific epoch. This - is a **randomized** algorithm. + Produces a digital signature for a given message at a specific epoch. + + This is a **randomized** algorithm. - **CRITICAL**: This function must never be called twice with the same secret - key and epoch for different messages, as this would compromise security. + **CRITICAL**: This function must never be called twice with the same + secret key and epoch for different messages, as this + would compromise security. The signing process involves: - 1. Repeatedly attempting to encode the message with fresh randomness - (`rho`) until a valid codeword is found. - 2. Computing the one-time signature, which consists of intermediate - values from the secret hash chains, determined by the digits of - the codeword. + 1. Repeatedly attempting to encode the message with + fresh randomness (`rho`) until a valid codeword is found. + 2. Computing the one-time signature, which consists of + intermediate values from the secret hash chains, + determined by the digits of the codeword. 3. Retrieving the Merkle authentication path for the given epoch. For the formal specification of this process, please refer to: @@ -216,32 +220,34 @@ def verify( self, pk: PublicKey, epoch: int, message: bytes, sig: Signature ) -> bool: r""" - Verifies a digital signature against a public key, message, and epoch. This - is a **deterministic** algorithm. + Verifies a digital signature against a public key, message, and epoch. - This function is a placeholder. The complete verification logic is detailed - below and will be implemented in a future update. + This is a **deterministic** algorithm. ### Verification Algorithm - 1. **Re-encode Message**: The verifier uses the randomness `rho` from the - signature to re-compute the codeword $x = (x_1, \dots, x_v)$ from the - message `m`. + 1. 
**Re-encode Message**: The verifier uses the randomness `rho` + from the signature to re-compute the codeword + $x = (x_1, \dots, x_v)$ from the message `m`. This includes calculating the checksum or checking the target sum. - 2. **Reconstruct One-Time Public Key**: For each intermediate hash $y_i$ - in the signature, the verifier completes the corresponding hash chain. + 2. **Reconstruct One-Time Public Key**: For each intermediate + hash $y_i$ in the signature, the verifier completes the + corresponding hash chain. + Since $y_i$ was computed with $x_i$ steps, the verifier applies the hash function an additional $w - 1 - x_i$ times to arrive at the one-time public key component $pk_{ep,i}$. - 3. **Compute Merkle Leaf**: The verifier hashes the reconstructed one-time - public key components to compute the expected Merkle leaf for `epoch`. + 3. **Compute Merkle Leaf**: The verifier hashes the reconstructed + one-time public key components to compute the expected Merkle + leaf for `epoch`. - 4. **Verify Merkle Path**: The verifier uses the `path` from the signature - to compute a candidate Merkle root starting from the computed leaf. - Verification succeeds if and only if this candidate root matches the - `root` in the `PublicKey`. + 4. **Verify Merkle Path**: The verifier uses the `path` from + the signature to compute a candidate Merkle root starting from + the computed leaf. + Verification succeeds if and only if this candidate root matches + the `root` in the `PublicKey`. Args: pk: The public key to verify against. diff --git a/src/lean_spec/subspecs/xmss/merkle_tree.py b/src/lean_spec/subspecs/xmss/merkle_tree.py index ab0a26eb..0c93b237 100644 --- a/src/lean_spec/subspecs/xmss/merkle_tree.py +++ b/src/lean_spec/subspecs/xmss/merkle_tree.py @@ -55,10 +55,10 @@ def _get_padded_layer( """ Pads a layer to ensure its nodes can always be paired up. 
- This helper function adds random padding to the start and/or end of a list - of nodes to enforce an invariant: every layer must start at an even index - and end at an odd index. This guarantees that every node has a sibling, - simplifying the construction of the next layer up. + This helper function adds random padding to the start and/or end of + a list of nodes to enforce an invariant: every layer must start at an + even index and end at an odd index. This guarantees that every node has + a sibling, simplifying the construction of the next layer up. Args: nodes: The list of active nodes for the current layer. @@ -100,14 +100,15 @@ def build( Builds a new sparse Merkle tree from a list of leaf hashes. The construction proceeds bottom-up, from the leaf layer to the root. - At each level, pairs of sibling nodes are hashed to create their parents - for the next level up. + At each level, pairs of sibling nodes are hashed to create + their parents for the next level up. Args: depth: The depth of the tree (e.g., 32 for a 2^32 leaf space). start_index: The index of the first leaf in `leaf_hashes`. parameter: The public parameter `P` for the hash function. - leaf_hashes: The list of pre-hashed leaf nodes to build the tree on. + leaf_hashes: The list of pre-hashed leaf nodes to + build the tree on. Returns: The fully constructed `HashTree` object. @@ -128,7 +129,8 @@ def build( strict=False, ) ): - # Calculate the position of the parent node in the next level up. + # Calculate the position of the parent node in the + # next level up. parent_index = (current_layer.start_index // 2) + i # Create the tweak for hashing these two children. tweak = TreeTweak(level=level + 1, index=parent_index) @@ -154,12 +156,13 @@ def get_path(self, tree: HashTree, position: int) -> HashTreeOpening: """ Computes the Merkle authentication path for a leaf at a given position. 
- The path consists of the list of sibling nodes required to reconstruct the - root, starting from the leaf's sibling and going up the tree. + The path consists of the list of sibling nodes required to reconstruct + the root, starting from the leaf's sibling and going up the tree. Args: tree: The `HashTree` from which to extract the path. - position: The absolute index of the leaf for which to generate path. + position: The absolute index of the leaf for which to + generate path. Returns: A `HashTreeOpening` object containing the co-path. @@ -167,7 +170,8 @@ def get_path(self, tree: HashTree, position: int) -> HashTreeOpening: co_path: List[HashDigest] = [] current_position = position - # Iterate from the bottom layer (level 0) up to the layer below the root. + # Iterate from the bottom layer (level 0) up to the layer + # below the root. for level in range(tree.depth): # Determine the sibling's position using an XOR operation. sibling_position = current_position ^ 1 @@ -192,9 +196,9 @@ def verify_path( """ Verifies a Merkle authentication path. - This function reconstructs a candidate root by starting with the leaf node - and repeatedly hashing it with the sibling nodes provided in the opening - path. + This function reconstructs a candidate root by starting with the leaf + node and repeatedly hashing it with the sibling nodes provided + in the opening path. The verification succeeds if the candidate root matches the true root. diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py index 7005111b..1266aaac 100644 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -84,7 +84,8 @@ def _map_into_hypercube_part(self, field_elements: List[Fp]) -> List[int]: # Reduce this integer modulo the size of the target domain. # - # The target domain is the set of all vertices in layers 0..FINAL_LAYER. + # The target domain is the set of all vertices + # in layers 0..FINAL_LAYER. 
domain_size = get_hypercube_part_size( config.BASE, config.DIMENSION, config.FINAL_LAYER ) @@ -106,8 +107,8 @@ def apply( """ Applies the full "Top Level" message hash procedure. - This involves multiple invocations of Poseidon2, with the combined output - mapped into a specific region of the hypercube. + This involves multiple invocations of Poseidon2, with the + combined output mapped into a specific region of the hypercube. Args: parameter: The public parameter `P`. @@ -116,7 +117,8 @@ def apply( message: The 32-byte message to be hashed. Returns: - A vertex in the hypercube, represented as a list of `DIMENSION` ints. + A vertex in the hypercube, represented as a list + of `DIMENSION` ints. """ # Encode the message and epoch as field elements. message_fe = self.encode_message(message) @@ -125,7 +127,8 @@ def apply( # Iteratively call Poseidon2 to generate a long hash output. poseidon_outputs: List[Fp] = [] for i in range(self.config.POS_INVOCATIONS): - # Use the iteration number as a domain separator for each hash call. + # Use the iteration number as a domain separator for + # each hash call. iteration_separator = [Fp(value=i)] # The input is: rho || P || epoch || message || iteration. diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 4cb9aeb1..0c1ef4fb 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -58,7 +58,8 @@ def safe_domain_separator( Computes a domain separator for the sponge construction. This function hashes a list of length parameters to create a unique - "capacity value" that configures the sponge for a specific hashing task. + "capacity value" that configures the sponge for a + specific hashing task. Args: lengths: A list of integer parameters defining the hash context. 
diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index b26620d2..2b0a3e58 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -72,7 +72,8 @@ def key_gen(self) -> PRFKey: """ Generates a cryptographically secure random key for the PRF. - This function sources randomness from the operating system's entropy pool. + This function sources randomness from the operating system's + entropy pool. Returns: A new, randomly generated PRF key of `PRF_KEY_LENGTH` bytes. diff --git a/src/lean_spec/subspecs/xmss/target_sum.py b/src/lean_spec/subspecs/xmss/target_sum.py index b03c055a..71e7bc82 100644 --- a/src/lean_spec/subspecs/xmss/target_sum.py +++ b/src/lean_spec/subspecs/xmss/target_sum.py @@ -31,10 +31,10 @@ def encode( """ Encodes a message into a codeword if it meets the target sum criteria. - This function first uses the message hash to map the input to a vertex in - the hypercube. It then checks if the sum of the vertex's coordinates - matches the scheme's `TARGET_SUM`. This filtering step is the core of the - Target Sum scheme. + This function first uses the message hash to map the input to a vertex + in the hypercube. It then checks if the sum of the vertex's coordinates + matches the scheme's `TARGET_SUM`. This filtering step is the core of + the Target Sum scheme. Args: parameter: The public parameter `P`. @@ -56,7 +56,8 @@ def encode( # If the sum is correct, this is a valid codeword. return codeword_candidate else: - # If the sum does not match, this `rho` is invalid for this message. + # If the sum does not match, this `rho` is invalid for + # this message. # # The caller (the `sign` function) will need to try again with new # randomness. 
diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py index 072c5324..d1cb170b 100644 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -66,8 +66,8 @@ def _encode_tweak(self, tweak: Tweak, length: int) -> List[Fp]: This function packs the tweak's integer components into a single large integer, then performs a base-`P` decomposition to get field elements. - This ensures a unique and deterministic mapping from any tweak to a format - consumable by the hash function. + This ensures a unique and deterministic mapping from any tweak to a + format consumable by the hash function. The packing scheme is designed to be injective: - `TreeTweak`: `(level << 40) | (index << 8) | TWEAK_PREFIX_TREE` @@ -117,9 +117,9 @@ def apply( """ Applies the tweakable Poseidon2 hash function to a message. - This function serves as the main entry point for all hashing operations. - It automatically selects the correct Poseidon2 mode (compression or sponge) - based on the number of message parts provided. + This function is the main entry point for all hashing operations. + It automatically selects the correct Poseidon2 mode + (compression or sponge) based on the number of message parts provided. Args: parameter: The public parameter `P` for the hash function. 
From fd5dd02fa3958e8c567be979269e28c0c54df3e0 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 18:19:27 +0200 Subject: [PATCH 10/13] add util for base p decomposition --- src/lean_spec/subspecs/xmss/message_hash.py | 13 +++-------- src/lean_spec/subspecs/xmss/poseidon.py | 9 ++++---- src/lean_spec/subspecs/xmss/tweak_hash.py | 9 +++----- src/lean_spec/subspecs/xmss/utils.py | 22 +++++++++++++++++++ .../subspecs/xmss/test_message_hash.py | 14 +++--------- 5 files changed, 35 insertions(+), 32 deletions(-) diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py index 1266aaac..e1f916f5 100644 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -20,6 +20,7 @@ TEST_POSEIDON, PoseidonXmss, ) +from lean_spec.subspecs.xmss.utils import int_to_base_p from ..koalabear import Fp, P from .constants import ( @@ -51,11 +52,7 @@ def encode_message(self, message: bytes) -> List[Fp]: acc = int.from_bytes(message, "little") # Decompose the integer into a list of field elements (base-P). - elements: List[Fp] = [] - for _ in range(self.config.MSG_LEN_FE): - elements.append(Fp(value=acc)) - acc //= P - return elements + return int_to_base_p(acc, self.config.MSG_LEN_FE) def encode_epoch(self, epoch: int) -> List[Fp]: """Encodes epoch and domain separator into a list of field elements.""" @@ -63,11 +60,7 @@ def encode_epoch(self, epoch: int) -> List[Fp]: acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE.value # Decompose the integer into its base-P representation. 
- elements: List[Fp] = [] - for _ in range(self.config.TWEAK_LEN_FE): - elements.append(Fp(value=acc)) - acc //= P - return elements + return int_to_base_p(acc, self.config.TWEAK_LEN_FE) def _map_into_hypercube_part(self, field_elements: List[Fp]) -> List[int]: """ diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 0c1ef4fb..900120ed 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -4,7 +4,9 @@ from typing import List -from ..koalabear import Fp, P +from lean_spec.subspecs.xmss.utils import int_to_base_p + +from ..koalabear import Fp from ..poseidon2.permutation import PARAMS_16, PARAMS_24, permute from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig from .structures import HashDigest @@ -76,10 +78,7 @@ def safe_domain_separator( # Decompose the integer into a list of 24 field elements for hashing. # # 24 is the fixed input width for this specific domain separation hash. - input_vec: List[Fp] = [] - for _ in range(24): - input_vec.append(Fp(value=acc)) - acc //= P + input_vec = int_to_base_p(acc, 24) # Compress the decomposed vector to produce the capacity value. return self.compress(input_vec, 24, capacity_len) diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py index d1cb170b..c82b3870 100644 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -21,9 +21,10 @@ TEST_POSEIDON, PoseidonXmss, ) +from lean_spec.subspecs.xmss.utils import int_to_base_p from lean_spec.types.base import StrictBaseModel -from ..koalabear import Fp, P +from ..koalabear import Fp from .constants import ( PROD_CONFIG, TEST_CONFIG, @@ -102,11 +103,7 @@ def _encode_tweak(self, tweak: Tweak, length: int) -> List[Fp]: # This is a standard base-P decomposition. # # The number of elements is determined by the `length` parameter. 
- elements: List[Fp] = [] - for _ in range(length): - elements.append(Fp(value=acc)) - acc //= P - return elements + return int_to_base_p(acc, length) def apply( self, diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py index 01aacc7b..4de82756 100644 --- a/src/lean_spec/subspecs/xmss/utils.py +++ b/src/lean_spec/subspecs/xmss/utils.py @@ -39,3 +39,25 @@ def rho(self) -> Randomness: TEST_RAND = Rand(TEST_CONFIG) """A lightweight instance for test environments.""" + + +def int_to_base_p(value: int, num_limbs: int) -> List[Fp]: + """ + Decomposes a large integer into a list of base-P field elements. + + This function performs a standard base conversion, where each "digit" + is an element in the prime field F_p. + + Args: + value: The integer to decompose. + num_limbs: The desired number of output field elements (limbs). + + Returns: + A list of `num_limbs` field elements representing the integer. + """ + limbs: List[Fp] = [] + acc = value + for _ in range(num_limbs): + limbs.append(Fp(value=acc % P)) + acc //= P + return limbs diff --git a/tests/lean_spec/subspecs/xmss/test_message_hash.py b/tests/lean_spec/subspecs/xmss/test_message_hash.py index 7a7c1c39..f895c0ed 100644 --- a/tests/lean_spec/subspecs/xmss/test_message_hash.py +++ b/tests/lean_spec/subspecs/xmss/test_message_hash.py @@ -2,8 +2,6 @@ Tests for the "Top Level" message hashing and encoding logic. 
""" -from typing import List - from lean_spec.subspecs.koalabear import Fp, P from lean_spec.subspecs.xmss.constants import ( TEST_CONFIG, @@ -12,7 +10,7 @@ from lean_spec.subspecs.xmss.message_hash import ( TEST_MESSAGE_HASHER, ) -from lean_spec.subspecs.xmss.utils import TEST_RAND +from lean_spec.subspecs.xmss.utils import TEST_RAND, int_to_base_p def test_encode_message() -> None: @@ -29,10 +27,7 @@ def test_encode_message() -> None: # All-max message (0xff) msg_max = b"\xff" * config.MESSAGE_LENGTH acc = int.from_bytes(msg_max, "little") - expected_max: List[Fp] = [] - for _ in range(config.MSG_LEN_FE): - expected_max.append(Fp(value=acc)) - acc //= P + expected_max = int_to_base_p(acc, config.MSG_LEN_FE) assert hasher.encode_message(msg_max) == expected_max @@ -47,10 +42,7 @@ def test_encode_epoch() -> None: test_epochs = [0, 42, 2**32 - 1] for epoch in test_epochs: acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE.value - expected: List[Fp] = [] - for _ in range(config.TWEAK_LEN_FE): - expected.append(Fp(value=acc)) - acc //= P + expected = int_to_base_p(acc, config.TWEAK_LEN_FE) assert hasher.encode_epoch(epoch) == expected # Test for injectivity. 
It is highly unlikely for a collision to occur From c8ca9bc1c92b6187c589b9361ffb4707a569381b Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 18:25:18 +0200 Subject: [PATCH 11/13] change poseidon xmss --- src/lean_spec/subspecs/xmss/poseidon.py | 28 +++++++++++++++---------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 900120ed..37728cdb 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -7,17 +7,22 @@ from lean_spec.subspecs.xmss.utils import int_to_base_p from ..koalabear import Fp -from ..poseidon2.permutation import PARAMS_16, PARAMS_24, permute -from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig +from ..poseidon2.permutation import ( + PARAMS_16, + PARAMS_24, + Poseidon2Params, + permute, +) from .structures import HashDigest class PoseidonXmss: """An instance of the Poseidon2-based tweakable hash for a given config.""" - def __init__(self, config: XmssConfig): - """Initializes the hasher with a specific parameter set.""" - self.config = config + def __init__(self, params16: Poseidon2Params, params24: Poseidon2Params): + """Initializes the hasher with specific Poseidon2 permutations.""" + self.params16 = params16 + self.params24 = params24 def compress( self, input_vec: List[Fp], width: int, output_len: int @@ -36,7 +41,8 @@ def compress( A hash digest of `output_len` field elements. """ # Select the correct permutation parameters based on the state width. - params = PARAMS_16 if width == 16 else PARAMS_24 + assert width in (16, 24), "Width must be 16 or 24" + params = self.params16 if width == 16 else self.params24 # Create a fixed-width buffer and copy the input, padding with zeros. padded_input = [Fp(value=0)] * width @@ -98,7 +104,7 @@ def sponge( A hash digest of `output_len` field elements. """ # Use the width-24 permutation for the sponge. 
- params = PARAMS_24 + params = self.params24 width = params.width rate = width - len(capacity_value) @@ -129,8 +135,8 @@ def sponge( return output[:output_len] -PROD_POSEIDON = PoseidonXmss(PROD_CONFIG) -"""An instance configured for production-level parameters.""" +# An instance configured for production-level parameters. +PROD_POSEIDON = PoseidonXmss(PARAMS_16, PARAMS_24) -TEST_POSEIDON = PoseidonXmss(TEST_CONFIG) -"""A lightweight instance for test environments.""" +# An instance for test environments (currently identical to PROD_POSEIDON). +TEST_POSEIDON = PoseidonXmss(PARAMS_16, PARAMS_24) From 60d9bacce9261d05b6b064d080475d2362d05df2 Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 18:26:55 +0200 Subject: [PATCH 12/13] export configs --- src/lean_spec/subspecs/xmss/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lean_spec/subspecs/xmss/__init__.py b/src/lean_spec/subspecs/xmss/__init__.py index 2ea7484a..33500ddf 100644 --- a/src/lean_spec/subspecs/xmss/__init__.py +++ b/src/lean_spec/subspecs/xmss/__init__.py @@ -5,6 +5,7 @@ It exposes the core data structures and the main interface functions.
""" +from .constants import PROD_CONFIG, TEST_CONFIG from .interface import GeneralizedXmssScheme from .structures import ( HashTree, @@ -21,4 +22,6 @@ "SecretKey", "HashTreeOpening", "HashTree", + "PROD_CONFIG", + "TEST_CONFIG", ] From ecc6863399309ccc1750f4fb7826836584c97eed Mon Sep 17 00:00:00 2001 From: Thomas Coratger Date: Sat, 30 Aug 2025 23:23:02 +0200 Subject: [PATCH 13/13] proofreading and some adjustments --- src/lean_spec/subspecs/xmss/__init__.py | 4 +-- src/lean_spec/subspecs/xmss/constants.py | 8 ----- .../xmss/{structures.py => containers.py} | 2 +- src/lean_spec/subspecs/xmss/hypercube.py | 27 ++++++++------ src/lean_spec/subspecs/xmss/interface.py | 27 +++++++++++--- src/lean_spec/subspecs/xmss/merkle_tree.py | 32 +++++++++++++++-- src/lean_spec/subspecs/xmss/message_hash.py | 14 +++++--- src/lean_spec/subspecs/xmss/poseidon.py | 29 +++++++++++---- src/lean_spec/subspecs/xmss/prf.py | 2 +- src/lean_spec/subspecs/xmss/target_sum.py | 5 ++- src/lean_spec/subspecs/xmss/tweak_hash.py | 15 ++------ src/lean_spec/subspecs/xmss/utils.py | 2 +- .../lean_spec/subspecs/xmss/test_hypercube.py | 35 ++++++++----------- .../lean_spec/subspecs/xmss/test_interface.py | 9 +---- .../subspecs/xmss/test_merkle_tree.py | 19 +++------- .../subspecs/xmss/test_message_hash.py | 2 +- tests/lean_spec/subspecs/xmss/test_prf.py | 4 +-- 17 files changed, 132 insertions(+), 104 deletions(-) rename src/lean_spec/subspecs/xmss/{structures.py => containers.py} (97%) diff --git a/src/lean_spec/subspecs/xmss/__init__.py b/src/lean_spec/subspecs/xmss/__init__.py index 33500ddf..0303880e 100644 --- a/src/lean_spec/subspecs/xmss/__init__.py +++ b/src/lean_spec/subspecs/xmss/__init__.py @@ -6,14 +6,14 @@ """ from .constants import PROD_CONFIG, TEST_CONFIG -from .interface import GeneralizedXmssScheme -from .structures import ( +from .containers import ( HashTree, HashTreeOpening, PublicKey, SecretKey, Signature, ) +from .interface import GeneralizedXmssScheme __all__ = [ 
"GeneralizedXmssScheme", diff --git a/src/lean_spec/subspecs/xmss/constants.py b/src/lean_spec/subspecs/xmss/constants.py index dec65017..2aeb6d36 100644 --- a/src/lean_spec/subspecs/xmss/constants.py +++ b/src/lean_spec/subspecs/xmss/constants.py @@ -18,8 +18,6 @@ specification in the future. """ -from typing import TypeVar - from pydantic import BaseModel, ConfigDict from typing_extensions import Final @@ -47,7 +45,6 @@ def LIFETIME(self) -> int: # noqa: N802 """ return 1 << self.LOG_LIFETIME - # --- Target Sum WOTS Parameters --- DIMENSION: int """The total number of hash chains, `v`.""" @@ -68,8 +65,6 @@ def LIFETIME(self) -> int: # noqa: N802 Should probably be modified in production. """ - # --- Hash and Encoding Length Parameters (in field elements) --- - PARAMETER_LEN: int """ The length of the public parameter `P`. @@ -104,9 +99,6 @@ def POS_OUTPUT_LEN_FE(self) -> int: # noqa: N802 return self.POS_OUTPUT_LEN_PER_INV_FE * self.POS_INVOCATIONS -Config = TypeVar("Config", bound=XmssConfig) -"""A type variable representing any XmssConfig instance.""" - PROD_CONFIG: Final = XmssConfig( MESSAGE_LENGTH=32, LOG_LIFETIME=32, diff --git a/src/lean_spec/subspecs/xmss/structures.py b/src/lean_spec/subspecs/xmss/containers.py similarity index 97% rename from src/lean_spec/subspecs/xmss/structures.py rename to src/lean_spec/subspecs/xmss/containers.py index b2d532ad..6a648dd7 100644 --- a/src/lean_spec/subspecs/xmss/structures.py +++ b/src/lean_spec/subspecs/xmss/containers.py @@ -1,4 +1,4 @@ -"""Defines the data structures for the Generalized XMSS signature scheme.""" +"""Defines the data containers for the Generalized XMSS signature scheme.""" from typing import Annotated, List diff --git a/src/lean_spec/subspecs/xmss/hypercube.py b/src/lean_spec/subspecs/xmss/hypercube.py index 42450857..c76521f2 100644 --- a/src/lean_spec/subspecs/xmss/hypercube.py +++ b/src/lean_spec/subspecs/xmss/hypercube.py @@ -7,7 +7,7 @@ to layer `d`, where `d` is its distance from the sink 
vertex `(w-1, ..., w-1)`. The core functionalities are: -1. **Precomputation and Caching**: Efficiently calculates and caches the sizes +1. **Precomputation and Caching**: Computes and caches the sizes of each layer for different hypercube configurations (`w` and `v`). 2. **Mapping**: Provides bijective mappings between an integer index within a layer and the unique vertex (a list of coordinates) it represents. @@ -60,15 +60,11 @@ def prepare_layer_info(w: int) -> List[LayerInfo]: """ Precomputes and caches the number of vertices in each layer of a hypercube. - This function is a crucial precomputation step for the mapping algorithms - used in the "Top Level Target Sum" encoding. It calculates the size of every layer for hypercubes with a given base `w` (where coordinates are in `[0, w-1]`) for all dimensions `v` up to `MAX_DIMENSION`. - The calculation is inductive: the layer sizes for a `v`-dimensional - hypercube are efficiently derived from the already-computed sizes - of a `(v-1)`-dimensional hypercube, based on the recurrence relation from + This precomputation is based on the recurrence relation from Lemma 8 of the paper "At the top of the hypercube" (eprint 2025/889). Args: @@ -87,6 +83,7 @@ def prepare_layer_info(w: int) -> List[LayerInfo]: # # For a 1-dimensional hypercube (v=1), which is just a line of `w` points # with coordinates [0], [1], ..., [w-1]. + # # The distance `d` from the sink `[w-1]` is simply `(w-1) - coordinate`. # Each of the `w` possible layers contains exactly one vertex. @@ -96,8 +93,6 @@ def prepare_layer_info(w: int) -> List[LayerInfo]: # Store the result for v=1, which will seed the inductive step. all_info[1] = LayerInfo(sizes=dim1_sizes, prefix_sums=dim1_prefix_sums) - # INDUCTIVE STEP - # # Now, build the layer info for all higher dimensions up to the maximum. for v in range(2, MAX_DIMENSION + 1): # The maximum possible distance `d` in a v-dimensional hypercube. 
@@ -118,17 +113,19 @@ def prepare_layer_info(w: int) -> List[LayerInfo]: # Translate the sum over `j` to an index range `k`, # where k = d - j. + # # This allows for an efficient lookup using prefix sums. k_min = d - j_max k_max = d - j_min - # Efficiently calculate the sum using the precomputed prefix sums + # Calculate the sum using the precomputed prefix sums # from the previous dimension's `LayerInfo`. layer_size = prev_layer_info.sizes_sum_in_range(k_min, k_max) current_sizes.append(layer_size) # After computing all layer sizes for dimension `v`, we compute their # prefix sums. + # # This is needed for the *next* iteration (for dimension v+1). current_prefix_sums: List[int] = [] current_sum = 0 @@ -150,12 +147,12 @@ def get_layer_size(w: int, v: int, d: int) -> int: return prepare_layer_info(w)[v].sizes[d] -def get_hypercube_part_size(w: int, v: int, d: int) -> int: +def hypercube_part_size(w: int, v: int, d: int) -> int: """Returns the total size of layers 0 to `d` (inclusive).""" return prepare_layer_info(w)[v].prefix_sums[d] -def find_layer(w: int, v: int, x: int) -> Tuple[int, int]: +def hypercube_find_layer(w: int, v: int, x: int) -> Tuple[int, int]: """ Given a global index `x`, finds the layer `d` it belongs to and its local index (`remainder`) within that layer. @@ -173,8 +170,16 @@ def find_layer(w: int, v: int, x: int) -> Tuple[int, int]: d = bisect.bisect_left(prefix_sums, x + 1) if d == 0: + # `x` is in the very first layer (d=0). + # + # The remainder is `x` itself, as the cumulative size of + # preceding layers is zero. remainder = x else: + # The cumulative size of all layers up to `d-1` is + # at `prefix_sums[d - 1]`. + # + # The remainder is `x` minus this cumulative size. 
remainder = x - prefix_sums[d - 1] return d, remainder diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index 190ae126..06107db4 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -21,13 +21,13 @@ TEST_CONFIG, XmssConfig, ) +from .containers import HashDigest, PublicKey, SecretKey, Signature from .merkle_tree import ( PROD_MERKLE_TREE, TEST_MERKLE_TREE, MerkleTree, ) from .prf import PROD_PRF, TEST_PRF, Prf -from .structures import HashDigest, PublicKey, SecretKey, Signature from .tweak_hash import ( PROD_TWEAK_HASHER, TEST_TWEAK_HASHER, @@ -130,7 +130,7 @@ def key_gen( tree = self.merkle_tree.build( config.LOG_LIFETIME, activation_epoch, parameter, leaf_hashes ) - root = self.merkle_tree.get_root(tree) + root = self.merkle_tree.root(tree) # Assemble and return the public and secret keys. pk = PublicKey(root=root, parameter=parameter) @@ -180,10 +180,13 @@ def sign(self, sk: SecretKey, epoch: int, message: bytes) -> Signature: codeword = None rho = None for _ in range(config.MAX_TRIES): + # Sample a randomness `rho` and try to encode the message. current_rho = self.rand.rho() current_codeword = self.encoder.encode( sk.parameter, message, current_rho, epoch ) + + # If a valid codeword is found, break out of the loop. if current_codeword is not None: codeword = current_codeword rho = current_rho @@ -193,6 +196,12 @@ def sign(self, sk: SecretKey, epoch: int, message: bytes) -> Signature: if codeword is None or rho is None: raise RuntimeError("Failed to find a valid message encoding.") + # Sanity check: the encoding must have the correct number of chunks. + if len(codeword) != self.config.DIMENSION: + raise RuntimeError( + "Encoding is broken: returned too many or too few chunks." + ) + # Compute the one-time signature hashes based on the codeword. 
ots_hashes: List[HashDigest] = [] for chain_index, steps in enumerate(codeword): @@ -211,7 +220,7 @@ def sign(self, sk: SecretKey, epoch: int, message: bytes) -> Signature: ots_hashes.append(ots_digest) # Get the Merkle authentication path for the current epoch. - path = self.merkle_tree.get_path(sk.tree, epoch) + path = self.merkle_tree.path(sk.tree, epoch) # Assemble and return the final signature. return Signature(path=path, rho=rho, hashes=ots_hashes) @@ -229,6 +238,7 @@ def verify( 1. **Re-encode Message**: The verifier uses the randomness `rho` from the signature to re-compute the codeword $x = (x_1, \dots, x_v)$ from the message `m`. + This includes calculating the checksum or checking the target sum. 2. **Reconstruct One-Time Public Key**: For each intermediate @@ -246,6 +256,7 @@ def verify( 4. **Verify Merkle Path**: The verifier uses the `path` from the signature to compute a candidate Merkle root starting from the computed leaf. + Verification succeeds if and only if this candidate root matches the `root` in the `PublicKey`. @@ -266,6 +277,10 @@ def verify( # Get the config for this scheme. config = self.config + # Check that the signature is for a valid epoch (valid range is 0..LIFETIME-1). + if epoch >= self.config.LIFETIME: + raise ValueError("The signature is for a future epoch.") + # Re-encode the message using the randomness `rho` from the signature. # # If the encoding is invalid, the signature is invalid. @@ -291,8 +306,10 @@ ) chain_ends.append(end_digest) - # Verify the Merkle path. This function internally hashes `chain_ends` - # to get the leaf node and then climbs the tree to recompute the root. + # Verify the Merkle path. + # + # This function internally hashes `chain_ends` to get the leaf node + # and then climbs the tree to recompute the root.
return self.merkle_tree.verify_path( parameter=pk.parameter, root=pk.root, diff --git a/src/lean_spec/subspecs/xmss/merkle_tree.py b/src/lean_spec/subspecs/xmss/merkle_tree.py index 0c93b237..9b49bf07 100644 --- a/src/lean_spec/subspecs/xmss/merkle_tree.py +++ b/src/lean_spec/subspecs/xmss/merkle_tree.py @@ -24,7 +24,7 @@ XmssConfig, ) -from .structures import ( +from .containers import ( HashDigest, HashTree, HashTreeLayer, @@ -113,6 +113,10 @@ def build( Returns: The fully constructed `HashTree` object. """ + # Check there is enough space for the leaves in the tree. + if start_index + len(leaf_hashes) > 2**depth: + raise ValueError("Not enough space for leaves in the tree.") + # Start with the leaf hashes and apply the initial padding. layers: List[HashTreeLayer] = [] current_layer = self._get_padded_layer(leaf_hashes, start_index) @@ -147,12 +151,12 @@ def build( return HashTree(depth=depth, layers=layers) - def get_root(self, tree: HashTree) -> HashDigest: + def root(self, tree: HashTree) -> HashDigest: """Extracts the root digest from a constructed Merkle tree.""" # The root is the single node in the final layer. return tree.layers[-1].nodes[0] - def get_path(self, tree: HashTree, position: int) -> HashTreeOpening: + def path(self, tree: HashTree, position: int) -> HashTreeOpening: """ Computes the Merkle authentication path for a leaf at a given position. @@ -167,6 +171,17 @@ def get_path( Returns: A `HashTreeOpening` object containing the co-path. """ + # Check that there is at least one layer in the tree. + if len(tree.layers) == 0: + raise ValueError("Cannot generate path for empty tree.") + + # Check that the position is within the tree's range.
+ if position < tree.layers[0].start_index: + raise ValueError("Position (before start) is invalid.") + + if position >= tree.layers[0].start_index + len(tree.layers[0].nodes): + raise ValueError("Position (after end) is invalid.") + co_path: List[HashDigest] = [] current_position = position @@ -212,6 +227,17 @@ def verify_path( Returns: `True` if the path is valid, `False` otherwise. """ + # Compute the depth + depth = len(opening.siblings) + # Compute the number of leafs in the tree + num_leafs = 2**depth + # Check that the tree depth is at most 32. + if len(opening.siblings) > 32: + raise ValueError("Tree depth must be at most 32.") + # Check that the position and path length match. + if position >= num_leafs: + raise ValueError("Position and path length do not match.") + # The first step is to hash the constituent parts of the leaf to get # the actual node at layer 0 of the tree. leaf_tweak = TreeTweak(level=0, index=position) diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py index e1f916f5..fde28537 100644 --- a/src/lean_spec/subspecs/xmss/message_hash.py +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -29,8 +29,12 @@ TWEAK_PREFIX_MESSAGE, XmssConfig, ) -from .hypercube import find_layer, get_hypercube_part_size, map_to_vertex -from .structures import Parameter, Randomness +from .containers import Parameter, Randomness +from .hypercube import ( + hypercube_find_layer, + hypercube_part_size, + map_to_vertex, +) class MessageHasher: @@ -79,13 +83,15 @@ def _map_into_hypercube_part(self, field_elements: List[Fp]) -> List[int]: # # The target domain is the set of all vertices # in layers 0..FINAL_LAYER. - domain_size = get_hypercube_part_size( + domain_size = hypercube_part_size( config.BASE, config.DIMENSION, config.FINAL_LAYER ) acc %= domain_size # Find which layer the resulting index falls into, and its offset. 
- layer, offset = find_layer(config.BASE, config.DIMENSION, acc) + layer, offset = hypercube_find_layer( + config.BASE, config.DIMENSION, acc + ) # Map the offset within the layer to a unique vertex. return map_to_vertex(config.BASE, config.DIMENSION, layer, offset) diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py index 37728cdb..7febf7cc 100644 --- a/src/lean_spec/subspecs/xmss/poseidon.py +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -1,4 +1,4 @@ -"""Defines the Tweakable Hash function using the Poseidon2 permutation.""" +"""Defines the Poseidon2 hash functions for the Generalized XMSS scheme.""" from __future__ import annotations @@ -13,7 +13,7 @@ Poseidon2Params, permute, ) -from .structures import HashDigest +from .containers import HashDigest class PoseidonXmss: @@ -40,6 +40,12 @@ def compress( Returns: A hash digest of `output_len` field elements. """ + # Check that the input vector is long enough to produce the output. + if len(input_vec) < output_len: + raise ValueError( + "Input vector is too short for requested output length." + ) + # Select the correct permutation parameters based on the state width. assert width in (16, 24), "Width must be 16 or 24" params = self.params16 if width == 16 else self.params24 @@ -90,7 +96,11 @@ def safe_domain_separator( return self.compress(input_vec, 24, capacity_len) def sponge( - self, input_vec: List[Fp], capacity_value: List[Fp], output_len: int + self, + input_vec: List[Fp], + capacity_value: List[Fp], + output_len: int, + width: int, ) -> HashDigest: """ A low-level wrapper for Poseidon2 in sponge mode. @@ -99,13 +109,20 @@ def sponge( input_vec: The input data of arbitrary length. capacity_value: A domain-separating value. output_len: The number of field elements in the output digest. + width: The width of the Poseidon2 permutation. Returns: A hash digest of `output_len` field elements. """ - # Use the width-24 permutation for the sponge. 
- params = self.params24 - width = params.width + # Ensure that the capacity value is not too long. + if len(capacity_value) >= width: + raise ValueError( + "Capacity length must be smaller than the state width." + ) + + # Get the correct permutation parameters. + assert width in (16, 24), "Width must be 16 or 24" + params = self.params16 if width == 16 else self.params24 rate = width - len(capacity_value) # Pad the input vector to be an exact multiple of the rate. diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py index 2b0a3e58..b57910af 100644 --- a/src/lean_spec/subspecs/xmss/prf.py +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -20,7 +20,7 @@ TEST_CONFIG, XmssConfig, ) -from lean_spec.subspecs.xmss.structures import PRFKey +from lean_spec.subspecs.xmss.containers import PRFKey from lean_spec.types.uint64 import Uint64 PRF_DOMAIN_SEP: bytes = bytes( diff --git a/src/lean_spec/subspecs/xmss/target_sum.py b/src/lean_spec/subspecs/xmss/target_sum.py index 71e7bc82..4507f106 100644 --- a/src/lean_spec/subspecs/xmss/target_sum.py +++ b/src/lean_spec/subspecs/xmss/target_sum.py @@ -9,12 +9,12 @@ from typing import List, Optional from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig +from .containers import Parameter, Randomness from .message_hash import ( PROD_MESSAGE_HASHER, TEST_MESSAGE_HASHER, MessageHasher, ) -from .structures import Parameter, Randomness class TargetSumEncoder: @@ -33,8 +33,7 @@ def encode( This function first uses the message hash to map the input to a vertex in the hypercube. It then checks if the sum of the vertex's coordinates - matches the scheme's `TARGET_SUM`. This filtering step is the core of - the Target Sum scheme. + matches the scheme's `TARGET_SUM`. Args: parameter: The public parameter `P`. 
diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py index c82b3870..298b77af 100644 --- a/src/lean_spec/subspecs/xmss/tweak_hash.py +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -1,13 +1,4 @@ -""" -Defines the Tweakable Hash function using Poseidon2. - -This module implements the core hashing logic for the XMSS scheme, including: -1. **Tweak Encoding**: Domain-separating different hash usages (chains/trees). -2. **Poseidon2 Compression**: Hashing fixed-size inputs. -3. **Poseidon2 Sponge**: Hashing variable-length inputs (e.g., leaf nodes). -4. **A unified `apply` function** that dispatches to the correct mode. -5. **A `chain` utility** to perform repeated hashing for WOTS chains. -""" +"""Defines the Tweakable Hash function using Poseidon2.""" from __future__ import annotations @@ -32,7 +23,7 @@ TWEAK_PREFIX_TREE, XmssConfig, ) -from .structures import HashDigest, Parameter +from .containers import HashDigest, Parameter class TreeTweak(StrictBaseModel): @@ -168,7 +159,7 @@ def apply( ) return self.poseidon.sponge( - input_vec, capacity_value, config.HASH_LEN_FE + input_vec, capacity_value, config.HASH_LEN_FE, 24 ) def hash_chain( diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py index 4de82756..109e03e3 100644 --- a/src/lean_spec/subspecs/xmss/utils.py +++ b/src/lean_spec/subspecs/xmss/utils.py @@ -5,7 +5,7 @@ from ..koalabear import Fp, P from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig -from .structures import HashDigest, Parameter, Randomness +from .containers import HashDigest, Parameter, Randomness class Rand: diff --git a/tests/lean_spec/subspecs/xmss/test_hypercube.py b/tests/lean_spec/subspecs/xmss/test_hypercube.py index 612ee858..dd49bf93 100644 --- a/tests/lean_spec/subspecs/xmss/test_hypercube.py +++ b/tests/lean_spec/subspecs/xmss/test_hypercube.py @@ -1,12 +1,4 @@ -""" -Tests for the hypercube mathematical operations. 
- -This module provides extensive tests for the hypercube logic, ensuring that -the precomputation of layer sizes is correct and that the mappings between -integer indices and hypercube vertices are bijective and accurate. - -The tests are designed to be an exact equivalent of the Rust reference tests. -""" +"""Tests for the hypercube mathematical operations.""" import math from functools import lru_cache @@ -16,9 +8,9 @@ from lean_spec.subspecs.xmss.hypercube import ( MAX_DIMENSION, - find_layer, - get_hypercube_part_size, get_layer_size, + hypercube_find_layer, + hypercube_part_size, map_to_vertex, prepare_layer_info, ) @@ -27,6 +19,7 @@ def map_to_integer(w: int, v: int, d: int, a: List[int]) -> int: """ Maps a vertex `a` in layer `d` back to its integer index. + This is a direct translation of the reference Rust implementation. """ if len(a) != v: @@ -144,39 +137,39 @@ def test_get_hypercube_part_size( w: int, v: int, d: int, expected_size: int ) -> None: """ - Tests `get_hypercube_part_size` with known values from the Rust tests. + Tests `hypercube_part_size` with known values from the Rust tests. """ - assert get_hypercube_part_size(w, v, d) == expected_size + assert hypercube_part_size(w, v, d) == expected_size def test_find_layer_boundaries() -> None: """ - Tests `find_layer` with specific boundary-crossing values. + Tests `hypercube_find_layer` with specific boundary-crossing values. """ w, v = 3, 2 # Layer sizes for (w=3, v=2) are [1, 2, 3, 2, 1] # Prefix sums are [1, 3, 6, 8, 9] # x=0 is the 1st element, which is in layer 0. Remainder is 0. - assert find_layer(w, v, 0) == (0, 0) + assert hypercube_find_layer(w, v, 0) == (0, 0) # x=1 is the 2nd element, which is the 1st element in layer 1. # Remainder is 0. - assert find_layer(w, v, 1) == (1, 0) + assert hypercube_find_layer(w, v, 1) == (1, 0) # x=2 is the 3rd element, which is the 2nd element in layer 1. # Remainder is 1. 
- assert find_layer(w, v, 2) == (1, 1) + assert hypercube_find_layer(w, v, 2) == (1, 1) # x=3 is the 4th element, which is the 1st element in layer 2. # Remainder is 0. - assert find_layer(w, v, 3) == (2, 0) + assert hypercube_find_layer(w, v, 3) == (2, 0) # x=5 is the 6th element, which is the third (last) element in layer 3. # Remainder is 2. - assert find_layer(w, v, 5) == (2, 2) + assert hypercube_find_layer(w, v, 5) == (2, 2) # x=6 is the 7th element, which is the first element in layer 3. # Remainder is 0. - assert find_layer(w, v, 6) == (3, 0) + assert hypercube_find_layer(w, v, 6) == (3, 0) # x=8 is the 9th element, which is the 1st element in layer 4. # Remainder is 0. - assert find_layer(w, v, 8) == (4, 0) + assert hypercube_find_layer(w, v, 8) == (4, 0) def test_map_to_vertex_roundtrip() -> None: diff --git a/tests/lean_spec/subspecs/xmss/test_interface.py b/tests/lean_spec/subspecs/xmss/test_interface.py index 80987eb8..5d7ea12d 100644 --- a/tests/lean_spec/subspecs/xmss/test_interface.py +++ b/tests/lean_spec/subspecs/xmss/test_interface.py @@ -9,10 +9,6 @@ GeneralizedXmssScheme, ) -# ================================================================= -# Test Helper Function -# ================================================================= - def _test_correctness_roundtrip( scheme: GeneralizedXmssScheme, @@ -74,10 +70,7 @@ def _test_correctness_roundtrip( def test_signature_scheme_correctness( activation_epoch: int, num_active_epochs: int, description: str ) -> None: - """ - Runs an end-to-end test of the signature scheme using the lightweight - TEST_SCHEME configuration across various lifetime scenarios. 
- """ + """Runs an end-to-end test of the signature scheme.""" _test_correctness_roundtrip( scheme=TEST_SIGNATURE_SCHEME, activation_epoch=activation_epoch, diff --git a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py index 0992922d..a0597253 100644 --- a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py +++ b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py @@ -1,18 +1,12 @@ -""" -Tests for the sparse Merkle tree implementation. - -This module verifies the correctness of the Merkle tree construction, -path generation, and verification logic by performing a full -"commit-open-verify" roundtrip across various tree configurations. -""" +"""Tests for the sparse Merkle tree implementation.""" import pytest +from lean_spec.subspecs.xmss.containers import HashDigest from lean_spec.subspecs.xmss.merkle_tree import ( PROD_MERKLE_TREE, MerkleTree, ) -from lean_spec.subspecs.xmss.structures import HashDigest from lean_spec.subspecs.xmss.tweak_hash import ( TreeTweak, ) @@ -59,12 +53,12 @@ def _run_commit_open_verify_roundtrip( # COMMIT: Build the Merkle tree from the leaf hashes. tree = merkle_tree.build(depth, start_index, parameter, leaf_hashes) - root = merkle_tree.get_root(tree) + root = merkle_tree.root(tree) # OPEN & VERIFY: For each leaf, generate and verify its path. for i, leaf_parts in enumerate(leaves): position = start_index + i - opening = merkle_tree.get_path(tree, position) + opening = merkle_tree.path(tree, position) is_valid = merkle_tree.verify_path( parameter, root, position, leaf_parts, opening ) @@ -91,10 +85,7 @@ def test_commit_open_verify_roundtrip( leaf_parts_len: int, description: str, ) -> None: - """ - Tests the Merkle tree logic for various configurations using test-specific - tree depths for efficiency. - """ + """Tests the Merkle tree logic for various configurations.""" # Ensure the test case parameters are valid for the specified tree depth. 
assert start_index + num_leaves <= (1 << depth) diff --git a/tests/lean_spec/subspecs/xmss/test_message_hash.py b/tests/lean_spec/subspecs/xmss/test_message_hash.py index f895c0ed..18476c3b 100644 --- a/tests/lean_spec/subspecs/xmss/test_message_hash.py +++ b/tests/lean_spec/subspecs/xmss/test_message_hash.py @@ -2,7 +2,7 @@ Tests for the "Top Level" message hashing and encoding logic. """ -from lean_spec.subspecs.koalabear import Fp, P +from lean_spec.subspecs.koalabear import Fp from lean_spec.subspecs.xmss.constants import ( TEST_CONFIG, TWEAK_PREFIX_MESSAGE, diff --git a/tests/lean_spec/subspecs/xmss/test_prf.py b/tests/lean_spec/subspecs/xmss/test_prf.py index 924a715d..519f360c 100644 --- a/tests/lean_spec/subspecs/xmss/test_prf.py +++ b/tests/lean_spec/subspecs/xmss/test_prf.py @@ -1,6 +1,4 @@ -""" -Tests for the SHAKE128-based pseudorandom function (PRF). -""" +"""Tests for the SHAKE128-based pseudorandom function (PRF).""" from lean_spec.subspecs.xmss.constants import ( PRF_KEY_LENGTH,