diff --git a/src/lean_spec/subspecs/xmss/__init__.py b/src/lean_spec/subspecs/xmss/__init__.py index 9e381730..0303880e 100644 --- a/src/lean_spec/subspecs/xmss/__init__.py +++ b/src/lean_spec/subspecs/xmss/__init__.py @@ -5,18 +5,23 @@ It exposes the core data structures and the main interface functions. """ -from .constants import LIFETIME, MESSAGE_LENGTH -from .interface import key_gen, sign, verify -from .structures import HashTreeOpening, PublicKey, SecretKey, Signature +from .constants import PROD_CONFIG, TEST_CONFIG +from .containers import ( + HashTree, + HashTreeOpening, + PublicKey, + SecretKey, + Signature, +) +from .interface import GeneralizedXmssScheme __all__ = [ - "key_gen", - "sign", - "verify", + "GeneralizedXmssScheme", "PublicKey", "Signature", "SecretKey", "HashTreeOpening", - "LIFETIME", - "MESSAGE_LENGTH", + "HashTree", + "PROD_CONFIG", + "TEST_CONFIG", ] diff --git a/src/lean_spec/subspecs/xmss/constants.py b/src/lean_spec/subspecs/xmss/constants.py index e2f14bb0..2aeb6d36 100644 --- a/src/lean_spec/subspecs/xmss/constants.py +++ b/src/lean_spec/subspecs/xmss/constants.py @@ -1,8 +1,12 @@ """ -Defines the cryptographic constants for the XMSS specification. +Defines the cryptographic constants and configuration presets for the +XMSS spec. This specification corresponds to the "hashing-optimized" Top Level Target Sum -instantiation from the canonical Rust implementation. +instantiation from the canonical Rust implementation +(production instantiation). + +We also provide a test instantiation for testing purposes. .. note:: This specification uses the **KoalaBear** prime field, which is consistent @@ -14,88 +18,133 @@ specification in the future. 
""" +from pydantic import BaseModel, ConfigDict +from typing_extensions import Final + from ..koalabear import Fp -# ================================================================= -# Core Scheme Configuration -# ================================================================= -MESSAGE_LENGTH: int = 32 -"""The length in bytes for all messages to be signed.""" +class XmssConfig(BaseModel): + """A model holding the configuration constants for an XMSS preset.""" -LOG_LIFETIME: int = 32 -"""The base-2 logarithm of the scheme's maximum lifetime.""" + model_config = ConfigDict(frozen=True, extra="forbid") -LIFETIME: int = 1 << LOG_LIFETIME -""" -The maximum number of epochs supported by this configuration. + # --- Core Scheme Configuration --- + MESSAGE_LENGTH: int + """The length in bytes for all messages to be signed.""" -An individual key pair can be active for a smaller sub-range. -""" + LOG_LIFETIME: int + """The base-2 logarithm of the scheme's maximum lifetime.""" + @property + def LIFETIME(self) -> int: # noqa: N802 + """ + The maximum number of epochs supported by this configuration. -# ================================================================= -# Target Sum WOTS Parameters -# ================================================================= + An individual key pair can be active for a smaller sub-range. 
+ """ + return 1 << self.LOG_LIFETIME -DIMENSION: int = 64 -"""The total number of hash chains, `v`.""" + DIMENSION: int + """The total number of hash chains, `v`.""" -BASE: int = 8 -"""The alphabet size for the digits of the encoded message.""" + BASE: int + """The alphabet size for the digits of the encoded message.""" -FINAL_LAYER: int = 77 -"""The number of top layers of the hypercube to map the hash output into.""" + FINAL_LAYER: int + """Number of top layers of the hypercube to map the hash output into.""" -TARGET_SUM: int = 375 -"""The required sum of all codeword chunks for a signature to be valid.""" + TARGET_SUM: int + """The required sum of all codeword chunks for a signature to be valid.""" + MAX_TRIES: int + """ + How often one should try at most to resample a random value. -# ================================================================= -# Hash and Encoding Length Parameters (in field elements) -# ================================================================= + This is currently based on experiments with the Rust implementation. + Should probably be modified in production. + """ -PARAMETER_LEN: int = 5 -""" -The length of the public parameter `P`. + PARAMETER_LEN: int + """ + The length of the public parameter `P`. -It is used to specialize the hash function. -""" + It is used to specialize the hash function. 
+ """ -TWEAK_LEN_FE: int = 2 -"""The length of a domain-separating tweak.""" + TWEAK_LEN_FE: int + """The length of a domain-separating tweak.""" -MSG_LEN_FE: int = 9 -"""The length of a message after being encoded into field elements.""" + MSG_LEN_FE: int + """The length of a message after being encoded into field elements.""" -RAND_LEN_FE: int = 7 -"""The length of the randomness `rho` used during message encoding.""" + RAND_LEN_FE: int + """The length of the randomness `rho` used during message encoding.""" -HASH_LEN_FE: int = 8 -"""The output length of the main tweakable hash function.""" + HASH_LEN_FE: int + """The output length of the main tweakable hash function.""" -CAPACITY: int = 9 -"""The capacity of the Poseidon2 sponge, defining its security level.""" + CAPACITY: int + """The capacity of the Poseidon2 sponge, defining its security level.""" -POS_OUTPUT_LEN_PER_INV_FE: int = 15 -"""Output length per invocation for the message hash.""" + POS_OUTPUT_LEN_PER_INV_FE: int + """Output length per invocation for the message hash.""" -POS_INVOCATIONS: int = 1 -"""Number of invocations for the message hash.""" + POS_INVOCATIONS: int + """Number of invocations for the message hash.""" -POS_OUTPUT_LEN_FE: int = POS_OUTPUT_LEN_PER_INV_FE * POS_INVOCATIONS -"""Total output length for the message hash.""" + @property + def POS_OUTPUT_LEN_FE(self) -> int: # noqa: N802 + """Total output length for the message hash.""" + return self.POS_OUTPUT_LEN_PER_INV_FE * self.POS_INVOCATIONS -# ================================================================= -# Domain Separator Prefixes for Tweaks -# ================================================================= +PROD_CONFIG: Final = XmssConfig( + MESSAGE_LENGTH=32, + LOG_LIFETIME=32, + DIMENSION=64, + BASE=8, + FINAL_LAYER=77, + TARGET_SUM=375, + MAX_TRIES=100_000, + PARAMETER_LEN=5, + TWEAK_LEN_FE=2, + MSG_LEN_FE=9, + RAND_LEN_FE=7, + HASH_LEN_FE=8, + CAPACITY=9, + POS_OUTPUT_LEN_PER_INV_FE=15, + POS_INVOCATIONS=1, +) 
-TWEAK_PREFIX_CHAIN = Fp(value=0x00) + +TEST_CONFIG: Final = XmssConfig( + MESSAGE_LENGTH=32, + LOG_LIFETIME=8, + DIMENSION=16, + BASE=4, + FINAL_LAYER=24, + TARGET_SUM=24, + MAX_TRIES=100_000, + PARAMETER_LEN=5, + TWEAK_LEN_FE=2, + MSG_LEN_FE=9, + RAND_LEN_FE=7, + HASH_LEN_FE=8, + CAPACITY=9, + POS_OUTPUT_LEN_PER_INV_FE=15, + POS_INVOCATIONS=1, +) + + +TWEAK_PREFIX_CHAIN: Final = Fp(value=0x00) """The unique prefix for tweaks used in Winternitz-style hash chains.""" -TWEAK_PREFIX_TREE = Fp(value=0x01) +TWEAK_PREFIX_TREE: Final = Fp(value=0x01) """The unique prefix for tweaks used when hashing Merkle tree nodes.""" -TWEAK_PREFIX_MESSAGE = Fp(value=0x02) +TWEAK_PREFIX_MESSAGE: Final = Fp(value=0x02) """The unique prefix for tweaks used in the initial message hashing step.""" + +PRF_KEY_LENGTH: int = 32 +"""The length of the PRF secret key in bytes.""" diff --git a/src/lean_spec/subspecs/xmss/containers.py b/src/lean_spec/subspecs/xmss/containers.py new file mode 100644 index 00000000..6a648dd7 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/containers.py @@ -0,0 +1,110 @@ +"""Defines the data containers for the Generalized XMSS signature scheme.""" + +from typing import Annotated, List + +from pydantic import BaseModel, ConfigDict, Field + +from ..koalabear import Fp +from .constants import PRF_KEY_LENGTH + +PRFKey = Annotated[ + bytes, Field(min_length=PRF_KEY_LENGTH, max_length=PRF_KEY_LENGTH) +] +""" +A type alias for the PRF secret key. + +It is a byte string of `PRF_KEY_LENGTH` bytes. +""" + + +HashDigest = List[Fp] +""" +A type alias representing a hash digest. +""" + +Parameter = List[Fp] +""" +A type alias representing the public parameter `P`. +""" + +Randomness = List[Fp] +""" +A type alias representing the randomness `rho`. +""" + + +class HashTreeOpening(BaseModel): + """ + A Merkle authentication path. + + It contains a list of sibling nodes required to reconstruct the path + from a leaf node up to the Merkle root. 
+    """
+
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    siblings: List[HashDigest] = Field(
+        ..., description="List of sibling hashes, from bottom to top."
+    )
+
+
+class HashTreeLayer(BaseModel):
+    """
+    Represents a single layer within the sparse Merkle tree.
+
+    Attributes:
+        start_index: The index of the first node in this layer within the full
+            conceptual tree.
+        nodes: A list of the actual hash digests stored for this layer.
+    """
+
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    start_index: int
+    """The starting index of the first node in this layer."""
+    nodes: List[HashDigest]
+    """A list of the actual hash digests stored for this layer."""
+
+
+class HashTree(BaseModel):
+    """
+    The complete sparse Merkle tree structure.
+
+    Attributes:
+        depth: The total depth of the tree (e.g., 32 for a 2^32 leaf space).
+        layers: A list of `HashTreeLayer` objects, from the leaf hashes
+            (layer 0) up to the layer just below the root.
+    """
+
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    depth: int
+    """The total depth of the tree (e.g., 32 for a 2^32 leaf space)."""
+    layers: List[HashTreeLayer]
+    """A list of `HashTreeLayer` objects, from the leaf hashes
+    (layer 0) up to the layer just below the root."""
+
+
+class PublicKey(BaseModel):
+    """The public key for the Generalized XMSS scheme."""
+
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    root: List[Fp]
+    parameter: Parameter
+
+
+class Signature(BaseModel):
+    """A signature in the Generalized XMSS scheme."""
+
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    path: HashTreeOpening
+    rho: Randomness
+    hashes: List[HashDigest]
+
+
+class SecretKey(BaseModel):
+    """The secret key for the Generalized XMSS scheme."""
+
+    model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True)
+    prf_key: PRFKey
+    tree: HashTree
+    parameter: Parameter
+    activation_epoch: int
+    num_active_epochs: 
int diff --git a/src/lean_spec/subspecs/xmss/hypercube.py b/src/lean_spec/subspecs/xmss/hypercube.py new file mode 100644 index 00000000..c76521f2 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/hypercube.py @@ -0,0 +1,254 @@ +""" +Implements the mathematical operations for hypercube layers. + +This module provides the necessary functions to work with vertices in a +v-dimensional hypercube with coordinates in the range [0, w-1]. A key concept +is the partitioning of the hypercube's vertices into "layers". A vertex belongs +to layer `d`, where `d` is its distance from the sink vertex `(w-1, ..., w-1)`. + +The core functionalities are: +1. **Precomputation and Caching**: Computes and caches the sizes + of each layer for different hypercube configurations (`w` and `v`). +2. **Mapping**: Provides bijective mappings between an integer index within a + layer and the unique vertex (a list of coordinates) it represents. + +This logic is a direct translation of the algorithms described in the paper +"At the top of the hypercube" (eprint 2025/889) +and the reference Rust implementation: https://github.com/b-wagn/hash-sig. +""" + +from __future__ import annotations + +import bisect +from functools import lru_cache +from typing import List, Tuple + +from pydantic import BaseModel, ConfigDict + +MAX_DIMENSION = 100 +"""The maximum dimension `v` for which layer sizes will be precomputed.""" + + +class LayerInfo(BaseModel): + """ + Stores the precomputed sizes and cumulative sums for a + hypercube's layers. + """ + + model_config = ConfigDict(frozen=True) + sizes: List[int] + """The number of vertices in each layer `d`.""" + prefix_sums: List[int] + """ + The cumulative number of vertices up to and including layer `d`. + + `prefix_sums[d] = sizes[0] + ... + sizes[d]`. 
+ """ + + def sizes_sum_in_range(self, start: int, end: int) -> int: + """Calculates the sum of sizes in an inclusive range [start, end].""" + if start > end: + return 0 + if start == 0: + return self.prefix_sums[end] + else: + return self.prefix_sums[end] - self.prefix_sums[start - 1] + + +@lru_cache(maxsize=None) +def prepare_layer_info(w: int) -> List[LayerInfo]: + """ + Precomputes and caches the number of vertices in each layer of a hypercube. + + It calculates the size of every layer for hypercubes with a given + base `w` (where coordinates are in `[0, w-1]`) for all dimensions + `v` up to `MAX_DIMENSION`. + + This precomputation is based on the recurrence relation from + Lemma 8 of the paper "At the top of the hypercube" (eprint 2025/889). + + Args: + w: The base of the hypercube. + + Returns: + A list where the element at index `v` is a `LayerInfo` object + containing the layer sizes for a `v`-dimensional hypercube. + """ + # Initialize a list to store the results for each dimension `v`. + # + # Index 0 is unused to allow for direct indexing, e.g., `all_info[v]`. + all_info = [LayerInfo(sizes=[], prefix_sums=[])] * (MAX_DIMENSION + 1) + + # BASE CASE + # + # For a 1-dimensional hypercube (v=1), which is just a line of `w` points + # with coordinates [0], [1], ..., [w-1]. + # + # The distance `d` from the sink `[w-1]` is simply `(w-1) - coordinate`. + + # Each of the `w` possible layers contains exactly one vertex. + dim1_sizes = [1] * w + # The prefix sums (cumulative sizes) are therefore just [1, 2, 3, ..., w]. + dim1_prefix_sums = list(range(1, w + 1)) + # Store the result for v=1, which will seed the inductive step. + all_info[1] = LayerInfo(sizes=dim1_sizes, prefix_sums=dim1_prefix_sums) + + # Now, build the layer info for all higher dimensions up to the maximum. + for v in range(2, MAX_DIMENSION + 1): + # The maximum possible distance `d` in a v-dimensional hypercube. 
+ max_d = (w - 1) * v + # Retrieve the already-computed data for the previous dimension (v-1). + prev_layer_info = all_info[v - 1] + + # This list will store the computed size of each layer `d` + # for dimension `v`. + current_sizes: List[int] = [] + for d in range(max_d + 1): + # Implements the recurrence l_d(v) = Σ l_{d-j}(v-1) from the paper. + # `j` is one coordinate's distance contribution. + + # Calculate the valid range [j_min, j_max] for `j`. + j_min = max(0, d - (w - 1) * (v - 1)) + j_max = min(w - 1, d) + + # Translate the sum over `j` to an index range `k`, + # where k = d - j. + # + # This allows for an efficient lookup using prefix sums. + k_min = d - j_max + k_max = d - j_min + + # Calculate the sum using the precomputed prefix sums + # from the previous dimension's `LayerInfo`. + layer_size = prev_layer_info.sizes_sum_in_range(k_min, k_max) + current_sizes.append(layer_size) + + # After computing all layer sizes for dimension `v`, we compute their + # prefix sums. + # + # This is needed for the *next* iteration (for dimension v+1). + current_prefix_sums: List[int] = [] + current_sum = 0 + for size in current_sizes: + current_sum += size + current_prefix_sums.append(current_sum) + + # Store the complete layer info for the current dimension `v`. + all_info[v] = LayerInfo( + sizes=current_sizes, prefix_sums=current_prefix_sums + ) + + # Return the complete table of layer information for the given base `w`. 
+ return all_info + + +def get_layer_size(w: int, v: int, d: int) -> int: + """Returns the size of a specific layer `d` in a `(w, v)` hypercube.""" + return prepare_layer_info(w)[v].sizes[d] + + +def hypercube_part_size(w: int, v: int, d: int) -> int: + """Returns the total size of layers 0 to `d` (inclusive).""" + return prepare_layer_info(w)[v].prefix_sums[d] + + +def hypercube_find_layer(w: int, v: int, x: int) -> Tuple[int, int]: + """ + Given a global index `x`, finds the layer `d` it belongs to and its + local index (`remainder`) within that layer. + + Args: + w: The hypercube base. + v: The hypercube dimension. + x: The global index of a vertex (from 0 to w**v - 1). + + Returns: + A tuple `(d, remainder)`. + """ + prefix_sums = prepare_layer_info(w)[v].prefix_sums + # Use binary search to efficiently find the correct layer. + d = bisect.bisect_left(prefix_sums, x + 1) + + if d == 0: + # `x` is in the very first layer (d=0). + # + # The remainder is `x` itself, as the cumulative size of + # preceding layers is zero. + remainder = x + else: + # The cumulative size of all layers up to `d-1` is + # at `prefix_sums[d - 1]`. + # + # The remainder is `x` minus this cumulative size. + remainder = x - prefix_sums[d - 1] + + return d, remainder + + +def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]: + """ + Maps an integer index `x` to a unique vertex in a specific hypercube layer. + + This function provides a bijective mapping from an integer `x` + (derived from a hash) to a unique list of `v` coordinates, + `[a_0, ..., a_{v-1}]`. + + The algorithm works iteratively, determining one coordinate at a time. + + Args: + w: The hypercube base (coordinates are in `[0, w-1]`). + v: The hypercube dimension (the number of coordinates). + d: The target layer, defined by its distance from the sink vertex. + x: The integer index within the layer `d`, must be `0 <= x < size(d)`. + + Returns: + A list of `v` integers representing the coordinates of the vertex. 
+ """ + layer_info_cache = prepare_layer_info(w) + + # Validate that the input index `x` is valid for the target layer. + layer_size = layer_info_cache[v].sizes[d] + if x >= layer_size: + raise ValueError("Index x is out of bounds for the given layer.") + + vertex: List[int] = [] + # Track remaining distance and index. + d_curr, x_curr = d, x + + # Determine each of the first v-1 coordinates iteratively. + for i in range(1, v): + dim_remaining = v - i + prev_dim_layer_info = layer_info_cache[dim_remaining] + + # This loop finds which block of sub-hypercubes the index `x_curr` + # falls into. + # + # It skips over full blocks by subtracting their size + # from `x_curr` until the correct one is found. + ji = -1 # Sentinel value + range_start = max(0, d_curr - (w - 1) * dim_remaining) + for j in range(range_start, min(w, d_curr + 1)): + count = prev_dim_layer_info.sizes[d_curr - j] + if x_curr >= count: + x_curr -= count + else: + ji = j # Found the correct block. + break + + if ji == -1: + raise RuntimeError( + "Internal logic error: failed to find coordinate" + ) + + # Convert the block's distance contribution `ji` to a coordinate `ai`. + ai = w - 1 - ji + vertex.append(ai) + + # Update the remaining distance for the next, smaller subproblem. + d_curr -= ji + + # The final coordinate is uniquely determined by the remaining values. + last_coord = w - 1 - x_curr - d_curr + vertex.append(last_coord) + + return vertex diff --git a/src/lean_spec/subspecs/xmss/interface.py b/src/lean_spec/subspecs/xmss/interface.py index e8e7e492..06107db4 100644 --- a/src/lean_spec/subspecs/xmss/interface.py +++ b/src/lean_spec/subspecs/xmss/interface.py @@ -1,109 +1,340 @@ """ Defines the core interface for the Generalized XMSS signature scheme. -Specification for the high-level functions (`key_gen`, `sign`, `verify`) -that constitute the public API of the signature scheme. For the purpose of this -specification, these are defined as placeholders with detailed documentation. 
+Specification for the high-level functions (`key_gen`, `sign`, `verify`). + +This constitutes the public API of the signature scheme. """ from __future__ import annotations -from typing import Tuple - -from .structures import PublicKey, SecretKey, Signature - - -def key_gen( - activation_epoch: int, num_active_epochs: int -) -> Tuple[PublicKey, SecretKey]: - """ - Generates a new cryptographic key pair. This is a **randomized** algorithm. - - This function is a placeholder. In a real implementation, it would involve - generating a master secret, deriving all one-time keys, and constructing - the full Merkle tree. - - Args: - activation_epoch: The starting epoch for which this key is active. - num_active_epochs: The number of consecutive epochs - the key is active for. - - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig - """ - raise NotImplementedError( - "key_gen is not part of this specification. " - "See the Rust reference implementation." - ) - - -def sign(sk: SecretKey, epoch: int, message: bytes) -> Signature: - """ - Produces a digital signature for a given message at a specific epoch. This - is a **randomized** algorithm. - - This function is a placeholder. The signing process involves encoding the - message, generating a one-time signature, and providing a Merkle path. - - **CRITICAL**: This function must never be called twice with the same secret - key and epoch for different messages, as this would compromise security. 
- - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig - """ - raise NotImplementedError( - "sign is not part of this specification. " - "See the Rust reference implementation." - ) - - -def verify(pk: PublicKey, epoch: int, message: bytes, sig: Signature) -> bool: - r""" - Verifies a digital signature against a public key, message, and epoch. This - is a **deterministic** algorithm. - - This function is a placeholder. The complete verification logic is detailed - below and will be implemented in a future update. - - ### Verification Algorithm - - 1. **Re-encode Message**: The verifier uses the randomness `rho` from the - signature to re-compute the codeword $x = (x_1, \dots, x_v)$ from the - message `m`. - This includes calculating the checksum or checking the target sum. - - 2. **Reconstruct One-Time Public Key**: For each intermediate hash $y_i$ - in the signature, the verifier completes the corresponding hash chain. - Since $y_i$ was computed with $x_i$ steps, the verifier applies the - hash function an additional $w - 1 - x_i$ times to arrive at the - one-time public key component $pk_{ep,i}$. - - 3. **Compute Merkle Leaf**: The verifier hashes the reconstructed one-time - public key components to compute the expected Merkle leaf for `epoch`. - - 4. **Verify Merkle Path**: The verifier uses the `path` from the signature - to compute a candidate Merkle root starting from the computed leaf. - Verification succeeds if and only if this candidate root matches the - `root` in the `PublicKey`. - - Args: - pk: The public key to verify against. - epoch: The epoch the signature corresponds to. - message: The message that was supposedly signed. - sig: The signature object to be verified. 
- - Returns: - `True` if the signature is valid, `False` otherwise. - - For the formal specification of this process, please refer to: - - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 - - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 - - The canonical Rust implementation: https://github.com/b-wagn/hash-sig - """ - raise NotImplementedError( - "verify will be implemented in a future update to the specification." - ) +from typing import List, Tuple + +from lean_spec.subspecs.xmss.target_sum import ( + PROD_TARGET_SUM_ENCODER, + TEST_TARGET_SUM_ENCODER, + TargetSumEncoder, +) + +from .constants import ( + PROD_CONFIG, + TEST_CONFIG, + XmssConfig, +) +from .containers import HashDigest, PublicKey, SecretKey, Signature +from .merkle_tree import ( + PROD_MERKLE_TREE, + TEST_MERKLE_TREE, + MerkleTree, +) +from .prf import PROD_PRF, TEST_PRF, Prf +from .tweak_hash import ( + PROD_TWEAK_HASHER, + TEST_TWEAK_HASHER, + TreeTweak, + TweakHasher, +) +from .utils import PROD_RAND, TEST_RAND, Rand + + +class GeneralizedXmssScheme: + """Instance of the Generalized XMSS signature scheme for a given config.""" + + def __init__( + self, + config: XmssConfig, + prf: Prf, + hasher: TweakHasher, + merkle_tree: MerkleTree, + encoder: TargetSumEncoder, + rand: Rand, + ): + """Initializes the scheme with a specific parameter set.""" + self.config = config + self.prf = prf + self.hasher = hasher + self.merkle_tree = merkle_tree + self.encoder = encoder + self.rand = rand + + def key_gen( + self, activation_epoch: int, num_active_epochs: int + ) -> Tuple[PublicKey, SecretKey]: + """ + Generates a new cryptographic key pair. + + This is a **randomized** algorithm. + + This function executes the full key generation process: + 1. Generates a master secret (PRF key) and a public hash parameter. + 2. 
For each epoch in the active range, it uses the PRF to derive the + secret starting points for all `DIMENSION` hash chains. + 3. It computes the public endpoint of each chain by + hashing `BASE - 1` times. + 4. The list of all chain endpoints for an epoch forms the + one-time public key. + This one-time public key is hashed to create a single Merkle leaf. + 5. A Merkle tree is built over all generated leaves, and its root + becomes part of the final public key. + + Args: + activation_epoch: The starting epoch for which this key is active. + num_active_epochs: The number of consecutive epochs + the key is active for. + + For the formal specification of this process, please refer to: + - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 + - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 + - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + """ + # Get the config for this scheme. + config = self.config + + # Validate the activation range against the scheme's total lifetime. + if activation_epoch + num_active_epochs > config.LIFETIME: + raise ValueError("Activation range exceeds the key's lifetime.") + + # Generate the random public parameter `P` and the master PRF key. + parameter = self.rand.parameter() + prf_key = self.prf.key_gen() + + # For each epoch, generate the corresponding Merkle leaf hash. + leaf_hashes: List[HashDigest] = [] + for epoch in range( + activation_epoch, activation_epoch + num_active_epochs + ): + # For each epoch, we compute `DIMENSION` chain endpoints. + chain_ends: List[HashDigest] = [] + for chain_index in range(config.DIMENSION): + # Derive the secret start of the chain from the master key. + start_digest = self.prf.apply(prf_key, epoch, chain_index) + # Compute the public end of the chain by hashing + # BASE - 1 times. 
+ end_digest = self.hasher.hash_chain( + parameter=parameter, + epoch=epoch, + chain_index=chain_index, + start_step=0, + num_steps=config.BASE - 1, + start_digest=start_digest, + ) + chain_ends.append(end_digest) + + # The Merkle leaf is the hash of all chain endpoints + # for this epoch. + leaf_tweak = TreeTweak(level=0, index=epoch) + leaf_hash = self.hasher.apply(parameter, leaf_tweak, chain_ends) + leaf_hashes.append(leaf_hash) + + # Build the Merkle tree over the generated leaves. + tree = self.merkle_tree.build( + config.LOG_LIFETIME, activation_epoch, parameter, leaf_hashes + ) + root = self.merkle_tree.root(tree) + + # Assemble and return the public and secret keys. + pk = PublicKey(root=root, parameter=parameter) + sk = SecretKey( + prf_key=prf_key, + tree=tree, + parameter=parameter, + activation_epoch=activation_epoch, + num_active_epochs=num_active_epochs, + ) + return pk, sk + + def sign(self, sk: SecretKey, epoch: int, message: bytes) -> Signature: + """ + Produces a digital signature for a given message at a specific epoch. + + This is a **randomized** algorithm. + + **CRITICAL**: This function must never be called twice with the same + secret key and epoch for different messages, as this + would compromise security. + + The signing process involves: + 1. Repeatedly attempting to encode the message with + fresh randomness (`rho`) until a valid codeword is found. + 2. Computing the one-time signature, which consists of + intermediate values from the secret hash chains, + determined by the digits of the codeword. + 3. Retrieving the Merkle authentication path for the given epoch. + + For the formal specification of this process, please refer to: + - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 + - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 + - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + """ + # Get the config for this scheme. 
+ config = self.config + + # Check that the key is active for the requested signing epoch. + active_range = range( + sk.activation_epoch, sk.activation_epoch + sk.num_active_epochs + ) + if epoch not in active_range: + raise ValueError("Key is not active for the specified epoch.") + + # Find a valid message encoding by trying different randomness `rho`. + codeword = None + rho = None + for _ in range(config.MAX_TRIES): + # Sample a randomness `rho` and try to encode the message. + current_rho = self.rand.rho() + current_codeword = self.encoder.encode( + sk.parameter, message, current_rho, epoch + ) + + # If a valid codeword is found, break out of the loop. + if current_codeword is not None: + codeword = current_codeword + rho = current_rho + break + + # If no valid encoding is found after many tries, signing fails. + if codeword is None or rho is None: + raise RuntimeError("Failed to find a valid message encoding.") + + # Sanity check: the encoding must have the correct number of chunks. + if len(codeword) != self.config.DIMENSION: + raise RuntimeError( + "Encoding is broken: returned too many or too few chunks." + ) + + # Compute the one-time signature hashes based on the codeword. + ots_hashes: List[HashDigest] = [] + for chain_index, steps in enumerate(codeword): + # Derive the secret start of the chain from the master key. + start_digest = self.prf.apply(sk.prf_key, epoch, chain_index) + # Walk the chain for the number of steps given + # by the codeword digit. + ots_digest = self.hasher.hash_chain( + parameter=sk.parameter, + epoch=epoch, + chain_index=chain_index, + start_step=0, + num_steps=steps, + start_digest=start_digest, + ) + ots_hashes.append(ots_digest) + + # Get the Merkle authentication path for the current epoch. + path = self.merkle_tree.path(sk.tree, epoch) + + # Assemble and return the final signature. 
+ return Signature(path=path, rho=rho, hashes=ots_hashes) + + def verify( + self, pk: PublicKey, epoch: int, message: bytes, sig: Signature + ) -> bool: + r""" + Verifies a digital signature against a public key, message, and epoch. + + This is a **deterministic** algorithm. + + ### Verification Algorithm + + 1. **Re-encode Message**: The verifier uses the randomness `rho` + from the signature to re-compute the codeword + $x = (x_1, \dots, x_v)$ from the message `m`. + + This includes calculating the checksum or checking the target sum. + + 2. **Reconstruct One-Time Public Key**: For each intermediate + hash $y_i$ in the signature, the verifier completes the + corresponding hash chain. + + Since $y_i$ was computed with $x_i$ steps, the verifier applies the + hash function an additional $w - 1 - x_i$ times to arrive at the + one-time public key component $pk_{ep,i}$. + + 3. **Compute Merkle Leaf**: The verifier hashes the reconstructed + one-time public key components to compute the expected Merkle + leaf for `epoch`. + + 4. **Verify Merkle Path**: The verifier uses the `path` from + the signature to compute a candidate Merkle root starting from + the computed leaf. + + Verification succeeds if and only if this candidate root matches + the `root` in the `PublicKey`. + + Args: + pk: The public key to verify against. + epoch: The epoch the signature corresponds to. + message: The message that was supposedly signed. + sig: The signature object to be verified. + + Returns: + `True` if the signature is valid, `False` otherwise. + + For the formal specification of this process, please refer to: + - "Hash-Based Multi-Signatures for Post-Quantum Ethereum": https://eprint.iacr.org/2025/055 + - "Technical Note: LeanSig for Post-Quantum Ethereum": https://eprint.iacr.org/2025/1332 + - The canonical Rust implementation: https://github.com/b-wagn/hash-sig + """ + # Get the config for this scheme. + config = self.config + + # Check that the signature is for a valid epoch. 
+        if epoch >= config.LIFETIME:
+            raise ValueError("The signature is for a future epoch.")
+
+        # Re-encode the message using the randomness `rho` from the signature.
+        #
+        # If the encoding is invalid, the signature is invalid.
+        codeword = self.encoder.encode(pk.parameter, message, sig.rho, epoch)
+        if codeword is None:
+            return False
+
+        # Reconstruct the one-time public key (the list of chain endpoints).
+        chain_ends: List[HashDigest] = []
+        for chain_index, xi in enumerate(codeword):
+            # The signature provides the hash value after `xi` steps.
+            start_digest = sig.hashes[chain_index]
+            # We must perform the remaining `BASE - 1 - xi` steps
+            # to get to the end.
+            num_steps_remaining = config.BASE - 1 - xi
+            end_digest = self.hasher.hash_chain(
+                parameter=pk.parameter,
+                epoch=epoch,
+                chain_index=chain_index,
+                start_step=xi,
+                num_steps=num_steps_remaining,
+                start_digest=start_digest,
+            )
+            chain_ends.append(end_digest)
+
+        # Verify the Merkle path.
+        #
+        # This function internally hashes `chain_ends` to get the leaf node
+        # and then climbs the tree to recompute the root.
+        return self.merkle_tree.verify_path(
+            parameter=pk.parameter,
+            root=pk.root,
+            position=epoch,
+            leaf_parts=chain_ends,
+            opening=sig.path,
+        )
+
+
+PROD_SIGNATURE_SCHEME = GeneralizedXmssScheme(
+    PROD_CONFIG,
+    PROD_PRF,
+    PROD_TWEAK_HASHER,
+    PROD_MERKLE_TREE,
+    PROD_TARGET_SUM_ENCODER,
+    PROD_RAND,
+)
+"""An instance configured for production-level parameters."""
+
+TEST_SIGNATURE_SCHEME = GeneralizedXmssScheme(
+    TEST_CONFIG,
+    TEST_PRF,
+    TEST_TWEAK_HASHER,
+    TEST_MERKLE_TREE,
+    TEST_TARGET_SUM_ENCODER,
+    TEST_RAND,
+)
+"""A lightweight instance for test environments."""
diff --git a/src/lean_spec/subspecs/xmss/merkle_tree.py b/src/lean_spec/subspecs/xmss/merkle_tree.py
new file mode 100644
index 00000000..9b49bf07
--- /dev/null
+++ b/src/lean_spec/subspecs/xmss/merkle_tree.py
@@ -0,0 +1,274 @@
+"""
+Implements the sparse Merkle tree used in the Generalized XMSS scheme.
+ +This module provides the data structures and algorithms for building a Merkle +tree over a contiguous subset of leaves, computing authentication paths, and +verifying those paths. + +The key features are: +1. **Sparsity**: The tree can represent a massive address space (e.g., 2^32 + leaves) while only storing the nodes relevant to a smaller, active range of + leaves (e.g., 2^20 leaves). +2. **Random Padding**: To simplify the logic for pairing sibling nodes, the + active layers are padded with random hash values to ensure they always + start at an even index and end at an odd index. +""" + +from __future__ import annotations + +from typing import List + +from lean_spec.subspecs.xmss.constants import ( + PROD_CONFIG, + TEST_CONFIG, + XmssConfig, +) + +from .containers import ( + HashDigest, + HashTree, + HashTreeLayer, + HashTreeOpening, + Parameter, +) +from .tweak_hash import ( + PROD_TWEAK_HASHER, + TEST_TWEAK_HASHER, + TreeTweak, + TweakHasher, +) +from .utils import PROD_RAND, TEST_RAND, Rand + + +class MerkleTree: + """An instance of the Merkle Tree handler for a given config.""" + + def __init__(self, config: XmssConfig, hasher: TweakHasher, rand: Rand): + """Initializes with a config, a hasher, and a random generator.""" + self.config = config + self.hasher = hasher + self.rand = rand + + def _get_padded_layer( + self, nodes: List[HashDigest], start_index: int + ) -> HashTreeLayer: + """ + Pads a layer to ensure its nodes can always be paired up. + + This helper function adds random padding to the start and/or end of + a list of nodes to enforce an invariant: every layer must start at an + even index and end at an odd index. This guarantees that every node has + a sibling, simplifying the construction of the next layer up. + + Args: + nodes: The list of active nodes for the current layer. + start_index: The starting index of the first node in `nodes`. + + Returns: + A new `HashTreeLayer` with padding applied. 
+ """ + nodes_with_padding: List[HashDigest] = [] + end_index = start_index + len(nodes) - 1 + + # Prepend random padding if the layer starts at an odd index. + if start_index % 2 == 1: + nodes_with_padding.append(self.rand.domain()) + + # The actual start index of the padded layer is always the even + # number at or immediately before the original start_index. + actual_start_index = start_index - (start_index % 2) + + # Add the actual node content. + nodes_with_padding.extend(nodes) + + # Append random padding if the layer ends at an even index. + if end_index % 2 == 0: + nodes_with_padding.append(self.rand.domain()) + + return HashTreeLayer( + start_index=actual_start_index, nodes=nodes_with_padding + ) + + def build( + self, + depth: int, + start_index: int, + parameter: Parameter, + leaf_hashes: List[HashDigest], + ) -> HashTree: + """ + Builds a new sparse Merkle tree from a list of leaf hashes. + + The construction proceeds bottom-up, from the leaf layer to the root. + At each level, pairs of sibling nodes are hashed to create + their parents for the next level up. + + Args: + depth: The depth of the tree (e.g., 32 for a 2^32 leaf space). + start_index: The index of the first leaf in `leaf_hashes`. + parameter: The public parameter `P` for the hash function. + leaf_hashes: The list of pre-hashed leaf nodes to + build the tree on. + + Returns: + The fully constructed `HashTree` object. + """ + # Check there is enough space for the leafs in the tree. + if start_index + len(leaf_hashes) > 2**depth: + raise ValueError("Not enough space for leafs in the tree.") + + # Start with the leaf hashes and apply the initial padding. + layers: List[HashTreeLayer] = [] + current_layer = self._get_padded_layer(leaf_hashes, start_index) + layers.append(current_layer) + + # Iterate from the leaf layer (level 0) up to the root. + for level in range(depth): + parents: List[HashDigest] = [] + # Group the current layer's nodes into pairs of siblings. 
+ for i, children in enumerate( + zip( + current_layer.nodes[0::2], + current_layer.nodes[1::2], + strict=False, + ) + ): + # Calculate the position of the parent node in the + # next level up. + parent_index = (current_layer.start_index // 2) + i + # Create the tweak for hashing these two children. + tweak = TreeTweak(level=level + 1, index=parent_index) + # Hash the left and right children to get their parent. + parent_node = self.hasher.apply( + parameter, tweak, list(children) + ) + parents.append(parent_node) + + # Pad the new list of parents to prepare for the next iteration. + new_start_index = current_layer.start_index // 2 + current_layer = self._get_padded_layer(parents, new_start_index) + layers.append(current_layer) + + return HashTree(depth=depth, layers=layers) + + def root(self, tree: HashTree) -> HashDigest: + """Extracts the root digest from a constructed Merkle tree.""" + # The root is the single node in the final layer. + return tree.layers[-1].nodes[0] + + def path(self, tree: HashTree, position: int) -> HashTreeOpening: + """ + Computes the Merkle authentication path for a leaf at a given position. + + The path consists of the list of sibling nodes required to reconstruct + the root, starting from the leaf's sibling and going up the tree. + + Args: + tree: The `HashTree` from which to extract the path. + position: The absolute index of the leaf for which to + generate path. + + Returns: + A `HashTreeOpening` object containing the co-path. + """ + # Check that there is at least one layer in the tree. + if len(tree.layers) == 0: + raise ValueError("Cannot generate path for empty tree.") + + # Check that the position is within the tree's range. 
+ if position < tree.layers[0].start_index: + raise ValueError("Position (before start) is invalid.") + + if position >= tree.layers[0].start_index + len(tree.layers[0].nodes): + raise ValueError("Position (after end) is invalid.") + + co_path: List[HashDigest] = [] + current_position = position + + # Iterate from the bottom layer (level 0) up to the layer + # below the root. + for level in range(tree.depth): + # Determine the sibling's position using an XOR operation. + sibling_position = current_position ^ 1 + # Find the sibling's index within the sparse `nodes` vector. + layer = tree.layers[level] + sibling_index_in_vec = sibling_position - layer.start_index + # Add the sibling's hash to the co-path. + co_path.append(layer.nodes[sibling_index_in_vec]) + # Move up to the parent's position for the next iteration. + current_position //= 2 + + return HashTreeOpening(siblings=co_path) + + def verify_path( + self, + parameter: Parameter, + root: HashDigest, + position: int, + leaf_parts: List[HashDigest], + opening: HashTreeOpening, + ) -> bool: + """ + Verifies a Merkle authentication path. + + This function reconstructs a candidate root by starting with the leaf + node and repeatedly hashing it with the sibling nodes provided + in the opening path. + + The verification succeeds if the candidate root matches the true root. + + Args: + parameter: The public parameter `P` for the hash function. + root: The known, trusted Merkle root. + position: The absolute index of the leaf being verified. + leaf_parts: The list of digests that constitute the original leaf. + opening: The `HashTreeOpening` object containing the sibling path. + + Returns: + `True` if the path is valid, `False` otherwise. + """ + # Compute the depth + depth = len(opening.siblings) + # Compute the number of leafs in the tree + num_leafs = 2**depth + # Check that the tree depth is at most 32. 
+ if len(opening.siblings) > 32: + raise ValueError("Tree depth must be at most 32.") + # Check that the position and path length match. + if position >= num_leafs: + raise ValueError("Position and path length do not match.") + + # The first step is to hash the constituent parts of the leaf to get + # the actual node at layer 0 of the tree. + leaf_tweak = TreeTweak(level=0, index=position) + current_node = self.hasher.apply(parameter, leaf_tweak, leaf_parts) + + # Iterate up the tree, hashing the current node with its sibling from + # the path at each level. + current_position = position + for level, sibling_node in enumerate(opening.siblings): + # Determine if the current node is a left or right child. + if current_position % 2 == 0: + # Current node is a left child; sibling is on the right. + children = [current_node, sibling_node] + else: + # Current node is a right child; sibling is on the left. + children = [sibling_node, current_node] + + # Move up to the parent's position for the next iteration. + current_position //= 2 + # Create the tweak for the parent's level and position. + parent_tweak = TreeTweak(level=level + 1, index=current_position) + # Hash the children to compute the parent node. + current_node = self.hasher.apply(parameter, parent_tweak, children) + + # After iterating through the entire path, the final computed node + # should be the root of the tree. + return current_node == root + + +PROD_MERKLE_TREE = MerkleTree(PROD_CONFIG, PROD_TWEAK_HASHER, PROD_RAND) +"""An instance configured for production-level parameters.""" + +TEST_MERKLE_TREE = MerkleTree(TEST_CONFIG, TEST_TWEAK_HASHER, TEST_RAND) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/message_hash.py b/src/lean_spec/subspecs/xmss/message_hash.py new file mode 100644 index 00000000..fde28537 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/message_hash.py @@ -0,0 +1,156 @@ +""" +Defines the "Top Level" message hashing for the signature scheme. 
+ +This module implements the logic for hashing a message into the top layers of +a hypercube, which is the first step in the "Top Level Target Sum" encoding. + +The process involves: +1. Encoding the message, epoch, and randomness into field elements. +2. Hashing these elements with Poseidon2 to produce a digest. +3. Interpreting the digest as a large integer and mapping it to a unique + vertex within the allowed top layers of the hypercube. +""" + +from __future__ import annotations + +from typing import List + +from lean_spec.subspecs.xmss.poseidon import ( + PROD_POSEIDON, + TEST_POSEIDON, + PoseidonXmss, +) +from lean_spec.subspecs.xmss.utils import int_to_base_p + +from ..koalabear import Fp, P +from .constants import ( + PROD_CONFIG, + TEST_CONFIG, + TWEAK_PREFIX_MESSAGE, + XmssConfig, +) +from .containers import Parameter, Randomness +from .hypercube import ( + hypercube_find_layer, + hypercube_part_size, + map_to_vertex, +) + + +class MessageHasher: + """An instance of the "Top Level" message hasher for a given config.""" + + def __init__(self, config: XmssConfig, poseidon_hasher: PoseidonXmss): + """Initializes the hasher with a specific parameter set.""" + self.config = config + self.poseidon = poseidon_hasher + + def encode_message(self, message: bytes) -> List[Fp]: + """ + Encodes a 32-byte message into a list of field elements. + + The message bytes are interpreted as a single little-endian integer, + which is then decomposed into its base-`P` representation. + """ + # Interpret the little-endian bytes as a single large integer. + acc = int.from_bytes(message, "little") + + # Decompose the integer into a list of field elements (base-P). + return int_to_base_p(acc, self.config.MSG_LEN_FE) + + def encode_epoch(self, epoch: int) -> List[Fp]: + """Encodes epoch and domain separator into a list of field elements.""" + # Combine the epoch and the message hash prefix into a single integer. 
+ acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE.value + + # Decompose the integer into its base-P representation. + return int_to_base_p(acc, self.config.TWEAK_LEN_FE) + + def _map_into_hypercube_part(self, field_elements: List[Fp]) -> List[int]: + """ + Maps a list of field elements to a vertex in + the top `FINAL_LAYER` layers. + """ + # Get the config for this scheme. + config = self.config + + # Combine field elements into one large integer (big-endian, base-P). + acc = 0 + for fe in field_elements: + acc = acc * P + fe.value + + # Reduce this integer modulo the size of the target domain. + # + # The target domain is the set of all vertices + # in layers 0..FINAL_LAYER. + domain_size = hypercube_part_size( + config.BASE, config.DIMENSION, config.FINAL_LAYER + ) + acc %= domain_size + + # Find which layer the resulting index falls into, and its offset. + layer, offset = hypercube_find_layer( + config.BASE, config.DIMENSION, acc + ) + + # Map the offset within the layer to a unique vertex. + return map_to_vertex(config.BASE, config.DIMENSION, layer, offset) + + def apply( + self, + parameter: Parameter, + epoch: int, + randomness: Randomness, + message: bytes, + ) -> List[int]: + """ + Applies the full "Top Level" message hash procedure. + + This involves multiple invocations of Poseidon2, with the + combined output mapped into a specific region of the hypercube. + + Args: + parameter: The public parameter `P`. + epoch: The current epoch. + randomness: A random value `rho`. + message: The 32-byte message to be hashed. + + Returns: + A vertex in the hypercube, represented as a list + of `DIMENSION` ints. + """ + # Encode the message and epoch as field elements. + message_fe = self.encode_message(message) + epoch_fe = self.encode_epoch(epoch) + + # Iteratively call Poseidon2 to generate a long hash output. + poseidon_outputs: List[Fp] = [] + for i in range(self.config.POS_INVOCATIONS): + # Use the iteration number as a domain separator for + # each hash call. 
+ iteration_separator = [Fp(value=i)] + + # The input is: rho || P || epoch || message || iteration. + combined_input = ( + randomness + + parameter + + epoch_fe + + message_fe + + iteration_separator + ) + + # Hash the combined input using Poseidon2 compression mode. + iteration_output = self.poseidon.compress( + combined_input, 24, self.config.POS_OUTPUT_LEN_PER_INV_FE + ) + poseidon_outputs.extend(iteration_output) + + # Map the final list of field elements into a hypercube vertex. + return self._map_into_hypercube_part(poseidon_outputs) + + +PROD_MESSAGE_HASHER = MessageHasher(PROD_CONFIG, PROD_POSEIDON) +"""An instance configured for production-level parameters.""" + +TEST_MESSAGE_HASHER = MessageHasher(TEST_CONFIG, TEST_POSEIDON) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/poseidon.py b/src/lean_spec/subspecs/xmss/poseidon.py new file mode 100644 index 00000000..7febf7cc --- /dev/null +++ b/src/lean_spec/subspecs/xmss/poseidon.py @@ -0,0 +1,159 @@ +"""Defines the Poseidon2 hash functions for the Generalized XMSS scheme.""" + +from __future__ import annotations + +from typing import List + +from lean_spec.subspecs.xmss.utils import int_to_base_p + +from ..koalabear import Fp +from ..poseidon2.permutation import ( + PARAMS_16, + PARAMS_24, + Poseidon2Params, + permute, +) +from .containers import HashDigest + + +class PoseidonXmss: + """An instance of the Poseidon2-based tweakable hash for a given config.""" + + def __init__(self, params16: Poseidon2Params, params24: Poseidon2Params): + """Initializes the hasher with specific Poseidon2 permutations.""" + self.params16 = params16 + self.params24 = params24 + + def compress( + self, input_vec: List[Fp], width: int, output_len: int + ) -> HashDigest: + """ + A low-level wrapper for Poseidon2 in compression mode. + + It computes: `Truncate(Permute(padded_input) + padded_input)`. + + Args: + input_vec: The input data to be hashed. 
+ width: The state width of the Poseidon2 permutation (16 or 24). + output_len: The number of field elements in the output digest. + + Returns: + A hash digest of `output_len` field elements. + """ + # Check that the input vector is long enough to produce the output. + if len(input_vec) < output_len: + raise ValueError( + "Input vector is too short for requested output length." + ) + + # Select the correct permutation parameters based on the state width. + assert width in (16, 24), "Width must be 16 or 24" + params = self.params16 if width == 16 else self.params24 + + # Create a fixed-width buffer and copy the input, padding with zeros. + padded_input = [Fp(value=0)] * width + padded_input[: len(input_vec)] = input_vec + + # Apply the Poseidon2 permutation. + permuted_state = permute(padded_input, params) + + # Apply the feed-forward step, adding the input back element-wise. + final_state = [ + p + i for p, i in zip(permuted_state, padded_input, strict=True) + ] + + # Truncate the state to the desired output length and return. + return final_state[:output_len] + + def safe_domain_separator( + self, lengths: List[int], capacity_len: int + ) -> List[Fp]: + """ + Computes a domain separator for the sponge construction. + + This function hashes a list of length parameters to create a unique + "capacity value" that configures the sponge for a + specific hashing task. + + Args: + lengths: A list of integer parameters defining the hash context. + capacity_len: The desired length of the output capacity value. + + Returns: + A list of `capacity_len` field elements. + """ + # Pack the length parameters into a single large integer. + acc = 0 + for length in lengths: + acc = (acc << 32) | length + + # Decompose the integer into a list of 24 field elements for hashing. + # + # 24 is the fixed input width for this specific domain separation hash. + input_vec = int_to_base_p(acc, 24) + + # Compress the decomposed vector to produce the capacity value. 
+ return self.compress(input_vec, 24, capacity_len) + + def sponge( + self, + input_vec: List[Fp], + capacity_value: List[Fp], + output_len: int, + width: int, + ) -> HashDigest: + """ + A low-level wrapper for Poseidon2 in sponge mode. + + Args: + input_vec: The input data of arbitrary length. + capacity_value: A domain-separating value. + output_len: The number of field elements in the output digest. + width: The width of the Poseidon2 permutation. + + Returns: + A hash digest of `output_len` field elements. + """ + # Ensure that the capacity value is not too long. + if len(capacity_value) >= width: + raise ValueError( + "Capacity length must be smaller than the state width." + ) + + # Get the correct permutation parameters. + assert width in (16, 24), "Width must be 16 or 24" + params = self.params16 if width == 16 else self.params24 + rate = width - len(capacity_value) + + # Pad the input vector to be an exact multiple of the rate. + num_extra = (rate - (len(input_vec) % rate)) % rate + padded_input = input_vec + [Fp(value=0)] * num_extra + + # Initialize the state with the capacity value. + state = [Fp(value=0)] * width + state[rate:] = capacity_value + + # Absorb the input in rate-sized chunks. + for i in range(0, len(padded_input), rate): + chunk = padded_input[i : i + rate] + # Add the chunk to the rate part of the state. + for j in range(rate): + state[j] += chunk[j] + # Apply the permutation. + state = permute(state, params) + + # Squeeze the output until enough elements have been generated. + output: HashDigest = [] + while len(output) < output_len: + output.extend(state[:rate]) + state = permute(state, params) + + # Truncate to the final output length and return. + return output[:output_len] + + +# An instance configured for production-level parameters. +PROD_POSEIDON = PoseidonXmss(PARAMS_16, PARAMS_24) + +# A lightweight instance for test environments. 
+TEST_POSEIDON = PoseidonXmss(PARAMS_16, PARAMS_24) diff --git a/src/lean_spec/subspecs/xmss/prf.py b/src/lean_spec/subspecs/xmss/prf.py new file mode 100644 index 00000000..b57910af --- /dev/null +++ b/src/lean_spec/subspecs/xmss/prf.py @@ -0,0 +1,152 @@ +""" +Defines the pseudorandom function (PRF) used in the signature scheme. + +PRF based on the SHAKE128 extendable-output function (XOF). + +The PRF is used to derive the secret starting points of the hash chains +for each epoch from a single master secret key. +""" + +from __future__ import annotations + +import hashlib +import os +from typing import List + +from lean_spec.subspecs.koalabear import Fp +from lean_spec.subspecs.xmss.constants import ( + PRF_KEY_LENGTH, + PROD_CONFIG, + TEST_CONFIG, + XmssConfig, +) +from lean_spec.subspecs.xmss.containers import PRFKey +from lean_spec.types.uint64 import Uint64 + +PRF_DOMAIN_SEP: bytes = bytes( + [ + 0xAE, + 0xAE, + 0x22, + 0xFF, + 0x00, + 0x01, + 0xFA, + 0xFF, + 0x21, + 0xAF, + 0x12, + 0x00, + 0x01, + 0x11, + 0xFF, + 0x00, + ] +) +""" +A 16-byte domain separator to ensure PRF outputs are unique to this context. + +This prevents any potential conflicts if the same underlying hash function +(SHAKE128) were used for other purposes in the system. +""" + +PRF_BYTES_PER_FE: int = 8 +""" +The number of bytes of SHAKE128 output used to generate one field element. + +We use 8 bytes (64 bits) of pseudorandom output, which is then reduced +modulo the 31-bit field prime `P`. This provides a significant statistical +safety margin to ensure the resulting field element is close to uniformly +random. +""" + + +class Prf: + """An instance of the SHAKE128-based PRF for a given config.""" + + def __init__(self, config: XmssConfig): + """Initializes the PRF with a specific parameter set.""" + self.config = config + + def key_gen(self) -> PRFKey: + """ + Generates a cryptographically secure random key for the PRF. 
+ + This function sources randomness from the operating system's + entropy pool. + + Returns: + A new, randomly generated PRF key of `PRF_KEY_LENGTH` bytes. + """ + return os.urandom(PRF_KEY_LENGTH) + + def apply(self, key: PRFKey, epoch: int, chain_index: Uint64) -> List[Fp]: + """ + Applies the PRF to derive the secret values for a specific epoch + and chain. + + The function computes: + `SHAKE128(DOMAIN_SEP || key || epoch || chain_index)` + and interprets the output as a list of field elements. + + Args: + key: The secret PRF key. + epoch: The epoch number (a 32-bit unsigned integer). + chain_index: The index of the hash chain (a 64-bit uint). + + Returns: + A list of `DIMENSION` field elements, which are the secret starting + points for the hash chains of the specified epoch. + """ + # Get the config for this scheme. + config = self.config + + # Create a new SHAKE128 hash instance. + hasher = hashlib.shake_128() + + # Absorb the domain separator to contextualize the hash. + hasher.update(PRF_DOMAIN_SEP) + + # Absorb the secret key. + hasher.update(key) + + # Absorb the epoch, represented as a 4-byte big-endian integer. + # + # This ensures that each epoch produces a unique set of secrets. + hasher.update(epoch.to_bytes(4, "big")) + + # Absorb the chain index, as an 8-byte big-endian integer. + # + # This is used to derive a unique start value for each hash chain. + hasher.update(chain_index.to_bytes(8, "big")) + + # Determine the total number of bytes to extract from the SHAKE output. + # + # For key generation, we need one field element per chain. + num_bytes_to_read = PRF_BYTES_PER_FE * config.HASH_LEN_FE + prf_output_bytes = hasher.digest(num_bytes_to_read) + + # Convert the byte output into a list of field elements. + output_elements: List[Fp] = [] + for i in range(config.HASH_LEN_FE): + # Extract an 8-byte chunk for the current field element. 
+ start = i * PRF_BYTES_PER_FE + end = start + PRF_BYTES_PER_FE + chunk = prf_output_bytes[start:end] + + # Convert the chunk to a large integer. + integer_value = int.from_bytes(chunk, "big") + + # Reduce the integer modulo the field prime `P`. + # + # The Fp constructor handles the modulo operation automatically. + output_elements.append(Fp(value=integer_value)) + + return output_elements + + +PROD_PRF = Prf(PROD_CONFIG) +"""An instance configured for production-level parameters.""" + +TEST_PRF = Prf(TEST_CONFIG) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/structures.py b/src/lean_spec/subspecs/xmss/structures.py deleted file mode 100644 index 306e9948..00000000 --- a/src/lean_spec/subspecs/xmss/structures.py +++ /dev/null @@ -1,75 +0,0 @@ -"""Defines the data structures for the Generalized XMSS signature scheme.""" - -from typing import Annotated, List - -from pydantic import BaseModel, ConfigDict, Field - -from ..koalabear import Fp -from .constants import HASH_LEN_FE, PARAMETER_LEN, RAND_LEN_FE - -HashDigest = Annotated[ - List[Fp], Field(min_length=HASH_LEN_FE, max_length=HASH_LEN_FE) -] -""" -A type alias representing a hash digest. -""" - -Parameter = Annotated[ - List[Fp], Field(min_length=PARAMETER_LEN, max_length=PARAMETER_LEN) -] -""" -A type alias representing the public parameter `P`. -""" - -Randomness = Annotated[ - List[Fp], Field(min_length=RAND_LEN_FE, max_length=RAND_LEN_FE) -] -""" -A type alias representing the randomness `rho`. -""" - - -class HashTreeOpening(BaseModel): - """ - A Merkle authentication path. - - It contains a list of sibling nodes required to reconstruct the path - from a leaf node up to the Merkle root. - """ - - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - siblings: List[HashDigest] = Field( - ..., description="List of sibling hashes, from bottom to top." 
- ) - - -class PublicKey(BaseModel): - """The public key for the Generalized XMSS scheme.""" - - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - root: List[Fp] = Field(..., max_length=HASH_LEN_FE, min_length=HASH_LEN_FE) - parameter: Parameter = Field( - ..., max_length=PARAMETER_LEN, min_length=PARAMETER_LEN - ) - - -class Signature(BaseModel): - """A signature in the Generalized XMSS scheme.""" - - model_config = ConfigDict(frozen=True, arbitrary_types_allowed=True) - path: HashTreeOpening - rho: Randomness = Field( - ..., max_length=RAND_LEN_FE, min_length=RAND_LEN_FE - ) - hashes: List[HashDigest] - - -class SecretKey(BaseModel): - """ - Placeholder for the secret key. - - Note: The full secret key structure is not specified here as it is not - needed for verification. - """ - - pass diff --git a/src/lean_spec/subspecs/xmss/target_sum.py b/src/lean_spec/subspecs/xmss/target_sum.py new file mode 100644 index 00000000..4507f106 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/target_sum.py @@ -0,0 +1,70 @@ +""" +Implements the Top Level Target Sum Winternitz incomparable encoding scheme. + +This module provides the logic for converting a message hash into a valid +codeword for the one-time signature part of the scheme. It acts as a filter on +top of the message hash output. 
+""" + +from typing import List, Optional + +from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig +from .containers import Parameter, Randomness +from .message_hash import ( + PROD_MESSAGE_HASHER, + TEST_MESSAGE_HASHER, + MessageHasher, +) + + +class TargetSumEncoder: + """An instance of the Target Sum encoder for a given config.""" + + def __init__(self, config: XmssConfig, message_hasher: MessageHasher): + """Initializes the encoder with a specific parameter set.""" + self.config = config + self.message_hasher = message_hasher + + def encode( + self, parameter: Parameter, message: bytes, rho: Randomness, epoch: int + ) -> Optional[List[int]]: + """ + Encodes a message into a codeword if it meets the target sum criteria. + + This function first uses the message hash to map the input to a vertex + in the hypercube. It then checks if the sum of the vertex's coordinates + matches the scheme's `TARGET_SUM`. + + Args: + parameter: The public parameter `P`. + message: The message to encode. + rho: The randomness used for this encoding attempt. + epoch: The current epoch. + + Returns: + The codeword (a list of integers) if the sum is correct, + otherwise `None`. + """ + # Apply the message hash to get a potential codeword (a vertex). + codeword_candidate = self.message_hasher.apply( + parameter, epoch, rho, message + ) + + # Check if the candidate satisfies the target sum condition. + if sum(codeword_candidate) == self.config.TARGET_SUM: + # If the sum is correct, this is a valid codeword. + return codeword_candidate + else: + # If the sum does not match, this `rho` is invalid for + # this message. + # + # The caller (the `sign` function) will need to try again with new + # randomness. 
+ return None + + +PROD_TARGET_SUM_ENCODER = TargetSumEncoder(PROD_CONFIG, PROD_MESSAGE_HASHER) +"""An instance configured for production-level parameters.""" + +TEST_TARGET_SUM_ENCODER = TargetSumEncoder(TEST_CONFIG, TEST_MESSAGE_HASHER) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/tweak_hash.py b/src/lean_spec/subspecs/xmss/tweak_hash.py new file mode 100644 index 00000000..298b77af --- /dev/null +++ b/src/lean_spec/subspecs/xmss/tweak_hash.py @@ -0,0 +1,203 @@ +"""Defines the Tweakable Hash function using Poseidon2.""" + +from __future__ import annotations + +from itertools import chain +from typing import List, Union + +from pydantic import Field + +from lean_spec.subspecs.xmss.poseidon import ( + PROD_POSEIDON, + TEST_POSEIDON, + PoseidonXmss, +) +from lean_spec.subspecs.xmss.utils import int_to_base_p +from lean_spec.types.base import StrictBaseModel + +from ..koalabear import Fp +from .constants import ( + PROD_CONFIG, + TEST_CONFIG, + TWEAK_PREFIX_CHAIN, + TWEAK_PREFIX_TREE, + XmssConfig, +) +from .containers import HashDigest, Parameter + + +class TreeTweak(StrictBaseModel): + """A tweak used for hashing nodes within the Merkle tree.""" + + level: int = Field(ge=0, description="The level in the Merkle tree.") + index: int = Field(ge=0, description="The node's index within that level.") + + +class ChainTweak(StrictBaseModel): + """A tweak used for hashing elements within a WOTS+ hash chain.""" + + epoch: int = Field(ge=0, description="The signature epoch.") + chain_index: int = Field(ge=0, description="The index of the hash chain.") + step: int = Field(ge=0, description="The step number within the chain.") + + +class TweakHasher: + """An instance of the Tweakable Hasher for a given config.""" + + def __init__(self, config: XmssConfig, poseidon_hasher: PoseidonXmss): + """Initializes the hasher with a specific parameter set.""" + self.config = config + self.poseidon = poseidon_hasher + + Tweak = 
Union[TreeTweak, ChainTweak] + """A type alias representing any valid tweak structure.""" + + def _encode_tweak(self, tweak: Tweak, length: int) -> List[Fp]: + """ + Encodes a tweak structure into a list of field elements for hashing. + + This function packs the tweak's integer components into a single large + integer, then performs a base-`P` decomposition to get field elements. + This ensures a unique and deterministic mapping from any tweak to a + format consumable by the hash function. + + The packing scheme is designed to be injective: + - `TreeTweak`: `(level << 40) | (index << 8) | TWEAK_PREFIX_TREE` + - `ChainTweak`: `(epoch << 24) | (chain_index << 16) + | (step << 8) | TWEAK_PREFIX_CHAIN` + + Args: + tweak: The `TreeTweak` or `ChainTweak` object. + length: The desired number of field elements in the output list. + + Returns: + A list of `length` field elements representing the encoded tweak. + """ + # Pack the tweak's integer fields into a single large integer. + # + # A unique prefix is included for domain separation. + if isinstance(tweak, TreeTweak): + acc = ( + (tweak.level << 40) + | (tweak.index << 8) + | TWEAK_PREFIX_TREE.value + ) + else: + acc = ( + (tweak.epoch << 24) + | (tweak.chain_index << 16) + | (tweak.step << 8) + | TWEAK_PREFIX_CHAIN.value + ) + + # Decompose the large integer `acc` into a list of field elements. + # This is a standard base-P decomposition. + # + # The number of elements is determined by the `length` parameter. + return int_to_base_p(acc, length) + + def apply( + self, + parameter: Parameter, + tweak: Tweak, + message_parts: List[HashDigest], + ) -> HashDigest: + """ + Applies the tweakable Poseidon2 hash function to a message. + + This function is the main entry point for all hashing operations. + It automatically selects the correct Poseidon2 mode + (compression or sponge) based on the number of message parts provided. + + Args: + parameter: The public parameter `P` for the hash function. 
+ tweak: A `TreeTweak` or `ChainTweak` for domain separation. + message_parts: A list of one or more hash digests to be hashed. + + Returns: + A new hash digest of `HASH_LEN_FE` field elements. + """ + # Get the config for this scheme. + config = self.config + + # Encode the tweak structure into field elements. + encoded_tweak = self._encode_tweak(tweak, config.TWEAK_LEN_FE) + + if len(message_parts) == 1: + # Case 1: Hashing a single digest (used in hash chains). + # + # We use the efficient width-16 compression mode. + input_vec = parameter + encoded_tweak + message_parts[0] + return self.poseidon.compress(input_vec, 16, config.HASH_LEN_FE) + + elif len(message_parts) == 2: + # Case 2: Hashing two digests (used for Merkle tree nodes). + # + # We use the width-24 compression mode. + input_vec = ( + parameter + encoded_tweak + message_parts[0] + message_parts[1] + ) + return self.poseidon.compress(input_vec, 24, config.HASH_LEN_FE) + + else: + # Case 3: Hashing many digests (used for the Merkle tree leaf). + # + # We use the robust sponge mode. + # First, flatten the list of message parts into a single vector. + flattened_message = list(chain.from_iterable(message_parts)) + input_vec = parameter + encoded_tweak + flattened_message + + # Create a domain separator based on the dimensions of the input. + lengths = [ + config.PARAMETER_LEN, + config.TWEAK_LEN_FE, + config.DIMENSION, + config.HASH_LEN_FE, + ] + capacity_value = self.poseidon.safe_domain_separator( + lengths, config.CAPACITY + ) + + return self.poseidon.sponge( + input_vec, capacity_value, config.HASH_LEN_FE, 24 + ) + + def hash_chain( + self, + parameter: Parameter, + epoch: int, + chain_index: int, + start_step: int, + num_steps: int, + start_digest: HashDigest, + ) -> HashDigest: + """ + Performs repeated hashing to traverse a WOTS+ hash chain. + + Args: + parameter: The public parameter `P`. + epoch: The signature epoch. + chain_index: The index of the hash chain. 
+ start_step: The starting step number in the chain. + num_steps: The number of hashing steps to perform. + start_digest: The digest to begin hashing from. + + Returns: + The final hash digest after `num_steps` applications. + """ + current_digest = start_digest + for i in range(num_steps): + # Create a unique tweak for the current position in the chain. + tweak = ChainTweak( + epoch=epoch, chain_index=chain_index, step=start_step + i + 1 + ) + # Apply the hash function. + current_digest = self.apply(parameter, tweak, [current_digest]) + return current_digest + + +PROD_TWEAK_HASHER = TweakHasher(PROD_CONFIG, PROD_POSEIDON) +"""An instance configured for production-level parameters.""" + +TEST_TWEAK_HASHER = TweakHasher(TEST_CONFIG, TEST_POSEIDON) +"""A lightweight instance for test environments.""" diff --git a/src/lean_spec/subspecs/xmss/utils.py b/src/lean_spec/subspecs/xmss/utils.py new file mode 100644 index 00000000..109e03e3 --- /dev/null +++ b/src/lean_spec/subspecs/xmss/utils.py @@ -0,0 +1,63 @@ +"""Utility functions for the XMSS signature scheme.""" + +import secrets +from typing import List + +from ..koalabear import Fp, P +from .constants import PROD_CONFIG, TEST_CONFIG, XmssConfig +from .containers import HashDigest, Parameter, Randomness + + +class Rand: + """An instance of the random data generator for a given config.""" + + def __init__(self, config: XmssConfig): + """Initializes the generator with a specific parameter set.""" + self.config = config + + def field_elements(self, length: int) -> List[Fp]: + """Generates a random list of field elements.""" + # For each element, generate a secure random integer + # in the range [0, P-1]. 
+ return [Fp(value=secrets.randbelow(P)) for _ in range(length)] + + def parameter(self) -> Parameter: + """Generates a random public parameter.""" + return self.field_elements(self.config.PARAMETER_LEN) + + def domain(self) -> HashDigest: + """Generates a random hash digest.""" + return self.field_elements(self.config.HASH_LEN_FE) + + def rho(self) -> Randomness: + """Generates randomness `rho` for message encoding.""" + return self.field_elements(self.config.RAND_LEN_FE) + + +PROD_RAND = Rand(PROD_CONFIG) +"""An instance configured for production-level parameters.""" + +TEST_RAND = Rand(TEST_CONFIG) +"""A lightweight instance for test environments.""" + + +def int_to_base_p(value: int, num_limbs: int) -> List[Fp]: + """ + Decomposes a large integer into a list of base-P field elements. + + This function performs a standard base conversion, where each "digit" + is an element in the prime field F_p. + + Args: + value: The integer to decompose. + num_limbs: The desired number of output field elements (limbs). + + Returns: + A list of `num_limbs` field elements representing the integer. + """ + limbs: List[Fp] = [] + acc = value + for _ in range(num_limbs): + limbs.append(Fp(value=acc % P)) + acc //= P + return limbs diff --git a/tests/lean_spec/subspecs/xmss/test_hypercube.py b/tests/lean_spec/subspecs/xmss/test_hypercube.py new file mode 100644 index 00000000..dd49bf93 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_hypercube.py @@ -0,0 +1,218 @@ +"""Tests for the hypercube mathematical operations.""" + +import math +from functools import lru_cache +from typing import List + +import pytest + +from lean_spec.subspecs.xmss.hypercube import ( + MAX_DIMENSION, + get_layer_size, + hypercube_find_layer, + hypercube_part_size, + map_to_vertex, + prepare_layer_info, +) + + +def map_to_integer(w: int, v: int, d: int, a: List[int]) -> int: + """ + Maps a vertex `a` in layer `d` back to its integer index. 
+ + This is a direct translation of the reference Rust implementation. + """ + if len(a) != v: + raise ValueError("Vertex length must equal dimension v.") + + layer_info_cache = prepare_layer_info(w) + x_curr = 0 + + # Initialize `d_curr` with the distance contribution + # of the last coordinate. + d_curr = (w - 1) - a[v - 1] + + # Loop backwards from the second-to-last coordinate to the first. + for i in range(v - 2, -1, -1): + ji = (w - 1) - a[i] + d_curr += ji + + # Dimension of the subproblem at this stage. + rem_dim = v - 1 - i + prev_dim_layer_info = layer_info_cache[rem_dim] + + # Calculate the start of the summation range for the subproblem. + j_start = max(0, d_curr - (w - 1) * rem_dim) + + # Add the sizes of all blocks that come before the current one. + sum_start = d_curr - (ji - 1) + sum_end = d_curr - j_start + x_curr += prev_dim_layer_info.sizes_sum_in_range(sum_start, sum_end) + + # At the end, the incrementally built distance must + # equal the target distance. + assert d_curr == d + + # The final accumulated value is the index. + return x_curr + + +@lru_cache(maxsize=None) +def _binom(n: int, k: int) -> int: + """A cached binomial coefficient calculator (n choose k).""" + if k < 0 or k > n: + return 0 + return math.comb(n, k) + + +def _nb(k: int, m: int, n: int) -> int: + """ + Computes the number of integer vectors of dimension `n` with entries in + [0, m] that sum to `k`. This is equivalent to the coefficient of x^k in + the polynomial (1 + x + ... + x^m)^n. + """ + total = 0 + for s in range(k // (m + 1) + 1): + term = _binom(n, s) * _binom(k - s * (m + 1) + n - 1, n - 1) + if s % 2 == 0: + total += term + else: + total -= term + return total + + +def _prepare_layer_sizes_by_binom(w: int) -> list[list[int]]: + """ + A reference implementation to calculate layer sizes using binomial + coefficients. It's slower but simpler, making it good for validation. 
+ """ + all_layers: List[List[int]] = [[] for _ in range(MAX_DIMENSION + 1)] + for v in range(1, MAX_DIMENSION + 1): + max_distance = (w - 1) * v + layer_sizes: List[int] = [] + for d in range(max_distance + 1): + # The sum of coordinates is `k = v * (w - 1) - d`. + # We need to find the number of ways to write `k` as a sum of `v` + # integers, each between 0 and `w-1`. + coord_sum = v * (w - 1) - d + layer_sizes.append(_nb(coord_sum, w - 1, v)) + all_layers[v] = layer_sizes + return all_layers + + +def test_prepare_layer_sizes_against_reference() -> None: + """ + Validates the optimized `prepare_layer_info` against the slower, + math-based reference implementation for a range of `w` values. + """ + for w in range(2, 7): + expected_sizes_by_v = _prepare_layer_sizes_by_binom(w) + actual_info_by_v = prepare_layer_info(w) + + for v in range(1, MAX_DIMENSION + 1): + # Note: The reference implementation returns reversed layer sizes. + # Layer `d` in our spec corresponds to sum `k = v*(w-1) - d`. + expected_sizes_reordered = list(reversed(expected_sizes_by_v[v])) + actual_sizes = actual_info_by_v[v].sizes + assert expected_sizes_reordered == actual_sizes + + +@pytest.mark.parametrize( + "w, v, d, expected_size", + [ + (2, 1, 0, 1), + (2, 1, 1, 2), + (3, 2, 0, 1), + (3, 2, 1, 3), + (3, 2, 2, 6), + (3, 2, 3, 8), + (3, 2, 4, 9), + (2, 3, 0, 1), + (2, 3, 1, 4), + (2, 3, 2, 7), + (2, 3, 3, 8), + ], +) +def test_get_hypercube_part_size( + w: int, v: int, d: int, expected_size: int +) -> None: + """ + Tests `hypercube_part_size` with known values from the Rust tests. + """ + assert hypercube_part_size(w, v, d) == expected_size + + +def test_find_layer_boundaries() -> None: + """ + Tests `hypercube_find_layer` with specific boundary-crossing values. + """ + w, v = 3, 2 + # Layer sizes for (w=3, v=2) are [1, 2, 3, 2, 1] + # Prefix sums are [1, 3, 6, 8, 9] + + # x=0 is the 1st element, which is in layer 0. Remainder is 0. 
+    assert hypercube_find_layer(w, v, 0) == (0, 0)
+    # x=1 is the 2nd element, which is the 1st element in layer 1.
+    # Remainder is 0.
+    assert hypercube_find_layer(w, v, 1) == (1, 0)
+    # x=2 is the 3rd element, which is the 2nd element in layer 1.
+    # Remainder is 1.
+    assert hypercube_find_layer(w, v, 2) == (1, 1)
+    # x=3 is the 4th element, which is the 1st element in layer 2.
+    # Remainder is 0.
+    assert hypercube_find_layer(w, v, 3) == (2, 0)
+    # x=5 is the 6th element, which is the third (last) element in layer 2.
+    # Remainder is 2.
+    assert hypercube_find_layer(w, v, 5) == (2, 2)
+    # x=6 is the 7th element, which is the first element in layer 3.
+    # Remainder is 0.
+    assert hypercube_find_layer(w, v, 6) == (3, 0)
+    # x=8 is the 9th element, which is the 1st element in layer 4.
+    # Remainder is 0.
+    assert hypercube_find_layer(w, v, 8) == (4, 0)
+
+
+def test_map_to_vertex_roundtrip() -> None:
+    """
+    Tests the map_to_vertex and map_to_integer roundtrip for a small case.
+    This test is slow and only checks a limited range.
+    """
+    w, v, d = 4, 8, 20
+    max_x = get_layer_size(w, v, d)
+
+    # Iterate through every possible index in a specific layer
+    # and check roundtrip
+    for x in range(
+        min(max_x, 100)
+    ):  # Capped at 100 iterations to keep test fast
+        vertex = map_to_vertex(w, v, d, x)
+
+        # Check that the vertex sum corresponds to the correct layer
+        coord_sum = sum(vertex)
+        assert (w - 1) * v - coord_sum == d
+
+        # Check that mapping back gives the original index
+        x_reconstructed = map_to_integer(w, v, d, vertex)
+        assert x_reconstructed == x
+
+
+def test_big_map() -> None:
+    """
+    Tests the full map_to_vertex -> map_to_integer roundtrip with a very
+    large number, exactly replicating the Rust reference test.
+    """
+    w, v, d = 12, 40, 174
+    x = 21790506781852242898091207809690042074412
+
+    # Map the integer to a vertex.
+    vertex_a = map_to_vertex(w, v, d, x)
+
+    # Map the vertex back to an integer.
+ x_reconstructed = map_to_integer(w, v, d, vertex_a) + + # Map the reconstructed integer back to a vertex again. + vertex_b = map_to_vertex(w, v, d, x_reconstructed) + + # Assert that both steps of the roundtrip were successful. + assert x == x_reconstructed + assert vertex_a == vertex_b diff --git a/tests/lean_spec/subspecs/xmss/test_interface.py b/tests/lean_spec/subspecs/xmss/test_interface.py new file mode 100644 index 00000000..5d7ea12d --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_interface.py @@ -0,0 +1,78 @@ +""" +End-to-end tests for the Generalized XMSS signature scheme. +""" + +import pytest + +from lean_spec.subspecs.xmss.interface import ( + TEST_SIGNATURE_SCHEME, + GeneralizedXmssScheme, +) + + +def _test_correctness_roundtrip( + scheme: GeneralizedXmssScheme, + activation_epoch: int, + num_active_epochs: int, +) -> None: + """ + A helper to perform a full key_gen -> sign -> verify roundtrip. + + It generates a key pair, signs a message at a specific epoch, and + verifies the signature. It also checks that verification fails for + an incorrect message or epoch. + """ + # KEY GENERATION + # + # Generate a new key pair for the specified active range. + pk, sk = scheme.key_gen(activation_epoch, num_active_epochs) + + # SIGN & VERIFY + # + # Pick a sample epoch within the active range to test signing. + test_epoch = activation_epoch + num_active_epochs // 2 + message = b"\x42" * scheme.config.MESSAGE_LENGTH + + # Sign the message at the chosen epoch. + # + # This might take a moment as it may try multiple `rho` values. + signature = scheme.sign(sk, test_epoch, message) + + # Verification of the valid signature must succeed. + is_valid = scheme.verify(pk, test_epoch, message, signature) + assert is_valid, "Verification of a valid signature failed" + + # TEST INVALID CASES + # + # Verification must fail if the message is tampered with. 
+ tampered_message = b"\x43" * scheme.config.MESSAGE_LENGTH + is_invalid_msg = scheme.verify(pk, test_epoch, tampered_message, signature) + assert not is_invalid_msg, "Verification succeeded for a tampered message" + + # Verification must fail if the epoch is incorrect. + if num_active_epochs > 1: + wrong_epoch = test_epoch + 1 + is_invalid_epoch = scheme.verify(pk, wrong_epoch, message, signature) + assert not is_invalid_epoch, ( + "Verification succeeded for an incorrect epoch" + ) + + +@pytest.mark.parametrize( + "activation_epoch, num_active_epochs, description", + [ + (10, 4, "Standard case with a short, active lifetime"), + (0, 8, "Lifetime starting at epoch 0"), + (20, 1, "Lifetime with only a single active epoch"), + (7, 5, "Lifetime starting at an odd-numbered epoch"), + ], +) +def test_signature_scheme_correctness( + activation_epoch: int, num_active_epochs: int, description: str +) -> None: + """Runs an end-to-end test of the signature scheme.""" + _test_correctness_roundtrip( + scheme=TEST_SIGNATURE_SCHEME, + activation_epoch=activation_epoch, + num_active_epochs=num_active_epochs, + ) diff --git a/tests/lean_spec/subspecs/xmss/test_merkle_tree.py b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py new file mode 100644 index 00000000..a0597253 --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_merkle_tree.py @@ -0,0 +1,94 @@ +"""Tests for the sparse Merkle tree implementation.""" + +import pytest + +from lean_spec.subspecs.xmss.containers import HashDigest +from lean_spec.subspecs.xmss.merkle_tree import ( + PROD_MERKLE_TREE, + MerkleTree, +) +from lean_spec.subspecs.xmss.tweak_hash import ( + TreeTweak, +) + + +def _run_commit_open_verify_roundtrip( + merkle_tree: MerkleTree, + num_leaves: int, + depth: int, + start_index: int, + leaf_parts_len: int, +) -> None: + """ + A helper function to perform a full Merkle tree roundtrip test. + + The process is as follows: + 1. Generate random leaf data. + 2. Hash the leaves to create layer 0 of the tree. 
+ 3. Build the full Merkle tree and get its root (commit). + 4. For each leaf, generate an authentication path (open). + 5. Verify that each path is valid for its corresponding leaf and root. + + Args: + num_leaves: The number of active leaves in the tree. + start_index: The starting index of the first active leaf. + leaf_parts_len: The number of digests that constitute a single leaf. + """ + # SETUP: Generate a random parameter and the raw leaf data. + parameter = merkle_tree.rand.parameter() + leaves: list[list[HashDigest]] = [ + [merkle_tree.rand.domain() for _ in range(leaf_parts_len)] + for _ in range(num_leaves) + ] + + # HASH LEAVES: Compute the layer 0 nodes by hashing the leaf parts. + leaf_hashes: list[HashDigest] = [ + merkle_tree.hasher.apply( + parameter, + TreeTweak(level=0, index=start_index + i), + leaf_parts, + ) + for i, leaf_parts in enumerate(leaves) + ] + + # COMMIT: Build the Merkle tree from the leaf hashes. + tree = merkle_tree.build(depth, start_index, parameter, leaf_hashes) + root = merkle_tree.root(tree) + + # OPEN & VERIFY: For each leaf, generate and verify its path. 
+ for i, leaf_parts in enumerate(leaves): + position = start_index + i + opening = merkle_tree.path(tree, position) + is_valid = merkle_tree.verify_path( + parameter, root, position, leaf_parts, opening + ) + assert is_valid, f"Verification failed for leaf at position {position}" + + +@pytest.mark.parametrize( + "num_leaves, depth, start_index, leaf_parts_len, description", + [ + (16, 4, 0, 3, "Full tree (depth 4)"), + (12, 5, 0, 5, "Half tree, left-aligned (depth 5)"), + (16, 5, 16, 2, "Half tree, right-aligned (depth 5)"), + (22, 6, 13, 3, "Sparse, non-aligned tree (depth 6)"), + (2, 2, 2, 6, "Half tree, right-aligned (small)"), + (1, 1, 0, 1, "Tree with a single leaf at the start"), + (1, 1, 1, 1, "Tree with a single leaf at an odd index"), + (16, 5, 7, 2, "Small sparse tree starting at an odd index"), + ], +) +def test_commit_open_verify_roundtrip( + num_leaves: int, + depth: int, + start_index: int, + leaf_parts_len: int, + description: str, +) -> None: + """Tests the Merkle tree logic for various configurations.""" + # Ensure the test case parameters are valid for the specified tree depth. + assert start_index + num_leaves <= (1 << depth) + + _run_commit_open_verify_roundtrip( + PROD_MERKLE_TREE, num_leaves, depth, start_index, leaf_parts_len + ) diff --git a/tests/lean_spec/subspecs/xmss/test_message_hash.py b/tests/lean_spec/subspecs/xmss/test_message_hash.py new file mode 100644 index 00000000..18476c3b --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_message_hash.py @@ -0,0 +1,95 @@ +""" +Tests for the "Top Level" message hashing and encoding logic. 
+""" + +from lean_spec.subspecs.koalabear import Fp +from lean_spec.subspecs.xmss.constants import ( + TEST_CONFIG, + TWEAK_PREFIX_MESSAGE, +) +from lean_spec.subspecs.xmss.message_hash import ( + TEST_MESSAGE_HASHER, +) +from lean_spec.subspecs.xmss.utils import TEST_RAND, int_to_base_p + + +def test_encode_message() -> None: + """Tests `encode_message` with various message patterns.""" + config = TEST_CONFIG + hasher = TEST_MESSAGE_HASHER + + # All-zero message + msg_zeros = b"\x00" * config.MESSAGE_LENGTH + encoded_zeros = hasher.encode_message(msg_zeros) + assert len(encoded_zeros) == config.MSG_LEN_FE + assert all(fe.value == 0 for fe in encoded_zeros) + + # All-max message (0xff) + msg_max = b"\xff" * config.MESSAGE_LENGTH + acc = int.from_bytes(msg_max, "little") + expected_max = int_to_base_p(acc, config.MSG_LEN_FE) + assert hasher.encode_message(msg_max) == expected_max + + +def test_encode_epoch() -> None: + """ + Tests `encode_epoch` for correctness and injectivity. + """ + hasher = TEST_MESSAGE_HASHER + config = TEST_CONFIG + + # Test specific values from the Rust reference tests. + test_epochs = [0, 42, 2**32 - 1] + for epoch in test_epochs: + acc = (epoch << 8) | TWEAK_PREFIX_MESSAGE.value + expected = int_to_base_p(acc, config.TWEAK_LEN_FE) + assert hasher.encode_epoch(epoch) == expected + + # Test for injectivity. It is highly unlikely for a collision to occur + # with a few random samples if the encoding is injective. + num_trials = 1000 + seen_encodings: set[tuple[Fp, ...]] = set() + for i in range(num_trials): + encoding = tuple(hasher.encode_epoch(i)) + assert encoding not in seen_encodings + seen_encodings.add(encoding) + + +def test_apply_output_is_in_correct_hypercube_part() -> None: + """ + Tests that the output of `apply` is a valid vertex that lies within + the top `FINAL_LAYER` layers of the hypercube. + """ + config = TEST_CONFIG + hasher = TEST_MESSAGE_HASHER + rand = TEST_RAND + + # Setup with random inputs. 
+ parameter = rand.parameter() + epoch = 313 + randomness = rand.rho() + message = b"\xaa" * config.MESSAGE_LENGTH + + # Call the message hash function. + vertex = hasher.apply(parameter, epoch, randomness, message) + + # Verify the properties of the output vertex. + # + # The length of the vertex must be equal to the hypercube's dimension. + assert len(vertex) == config.DIMENSION + # Each coordinate must be smaller than the base `w`. + assert all(0 <= coord < config.BASE for coord in vertex) + + # Check that the vertex lies in the correct set of layers. + # + # By definition, a vertex is in layer `d` if `d = v*(w-1) - sum(coords)`. + # + # We require `d <= FINAL_LAYER`. + # + # This is equivalent to `sum(coords) >= v*(w-1) - FINAL_LAYER`. + coord_sum = sum(vertex) + min_required_sum = ( + config.BASE - 1 + ) * config.DIMENSION - config.FINAL_LAYER + + assert coord_sum >= min_required_sum, "Vertex is not in the top layers" diff --git a/tests/lean_spec/subspecs/xmss/test_prf.py b/tests/lean_spec/subspecs/xmss/test_prf.py new file mode 100644 index 00000000..519f360c --- /dev/null +++ b/tests/lean_spec/subspecs/xmss/test_prf.py @@ -0,0 +1,73 @@ +"""Tests for the SHAKE128-based pseudorandom function (PRF).""" + +from lean_spec.subspecs.xmss.constants import ( + PRF_KEY_LENGTH, + TEST_CONFIG, +) +from lean_spec.subspecs.xmss.prf import TEST_PRF + + +def test_key_gen_is_random() -> None: + """ + Performs a sanity check on `key_gen` to ensure it's not deterministic + or producing trivial outputs. + + This test mirrors the logic from the reference Rust implementation. + """ + prf = TEST_PRF + + # Check that the key has the correct length. + key = prf.key_gen() + assert len(key) == PRF_KEY_LENGTH + + # Generate multiple keys and ensure they are not all identical. + # + # This is a basic check to ensure we are getting fresh randomness. 
+ num_trials = 10 + keys = {prf.key_gen() for _ in range(num_trials)} + assert len(keys) == num_trials + + # Check that the keys are not filled with a single repeated byte. + # + # It is astronomically unlikely for a secure random generator to produce + # such a key, so this is a good health check. + all_same_count = 0 + for _ in range(num_trials): + key = prf.key_gen() + # A set will have size 1 if all elements are the same. + if len(set(key)) == 1: + all_same_count += 1 + assert all_same_count < num_trials, "key_gen produced non-random keys" + + +def test_apply_is_sensitive_to_inputs() -> None: + """ + Tests that changing any input to `apply` results in a different output. + + This confirms that all parts of the input (key, epoch, chain_index) are + being correctly absorbed by the hash function. + """ + prf = TEST_PRF + config = TEST_CONFIG + + # Generate a baseline output with a set of initial inputs. + key1 = b"\x11" * PRF_KEY_LENGTH + epoch1 = 10 + chain_index1 = 20 + baseline_output = prf.apply(key1, epoch1, chain_index1) + assert len(baseline_output) == config.HASH_LEN_FE + + # Test sensitivity to the key. + key2 = b"\x22" * PRF_KEY_LENGTH + output_key_changed = prf.apply(key2, epoch1, chain_index1) + assert baseline_output != output_key_changed + + # Test sensitivity to the epoch. + epoch2 = 11 + output_epoch_changed = prf.apply(key1, epoch2, chain_index1) + assert baseline_output != output_epoch_changed + + # Test sensitivity to the chain_index. + chain_index2 = 21 + output_index_changed = prf.apply(key1, epoch1, chain_index2) + assert baseline_output != output_index_changed