diff --git a/src/lean_spec/subspecs/xmss/hypercube.py b/src/lean_spec/subspecs/xmss/hypercube.py index f08b6c27..a39d49d0 100644 --- a/src/lean_spec/subspecs/xmss/hypercube.py +++ b/src/lean_spec/subspecs/xmss/hypercube.py @@ -1,16 +1,28 @@ """ -Implements the mathematical operations for hypercube layers. - -This module provides the necessary functions to work with vertices in a -v-dimensional hypercube with coordinates in the range [0, w-1]. A key concept -is the partitioning of the hypercube's vertices into "layers". A vertex belongs -to layer `d`, where `d` is its distance from the sink vertex `(w-1, ..., w-1)`. - -The core functionalities are: -1. **Precomputation and Caching**: Computes and caches the sizes - of each layer for different hypercube configurations (`w` and `v`). -2. **Mapping**: Provides bijective mappings between an integer index within a - layer and the unique vertex (a list of coordinates) it represents. +Implements the mathematical operations for hypercube-based encodings. + +This module provides the core algorithms for working with vertices in a +v-dimensional hypercube where coordinates are in the range [0, w-1]. This is a +foundational component for the "Top of the Hypercube" signature schemes. + +Core Concepts +------------- +1. **Hypercube (`[w]^v`)**: The set of all possible coordinate vectors for a + signature. The dimension `v` corresponds to the number of hash chains (and + thus signature size), and the base `w` corresponds to the length of each + hash chain. + +2. **Layers (`d`)**: The hypercube's vertices are partitioned into "layers" based + on their **verification cost**. The layer `d` of a vertex is its distance + from the sink vertex `(w-1, ..., w-1)`, calculated as: + `d = (w-1)*v - sum(coordinates)`. + A smaller `d` means a lower verification cost, making these "top layers" + the most desirable for encoding messages into. + +3. **Mapping Problem**: The central challenge is to deterministically and + efficiently map a single integer (derived from a message hash) to a unique + coordinate vector within a specific layer or a set of top layers. This + module provides the building blocks for that mapping. This logic is a direct translation of the algorithms described in the paper "At the top of the hypercube" (eprint 2025/889) @@ -20,7 +32,9 @@ from __future__ import annotations import bisect +import math from functools import lru_cache +from itertools import accumulate from typing import List, Tuple from pydantic import BaseModel, ConfigDict @@ -31,112 +45,117 @@ class LayerInfo(BaseModel): """ - Stores the precomputed sizes and cumulative sums for a - hypercube's layers. + A data structure to store precomputed sizes and cumulative sums for the + layers of a single hypercube configuration (fixed `w` and `v`). + + This object makes subsequent calculations, like finding the total size of a + range of layers, highly efficient. """ model_config = ConfigDict(frozen=True) sizes: List[int] - """The number of vertices in each layer `d`.""" + """A list where `sizes[d]` is the number of vertices in layer `d`.""" prefix_sums: List[int] """ - The cumulative number of vertices up to and including layer `d`. + A list where `prefix_sums[d]` is the cumulative number of vertices from + layer 0 up to and including layer `d`. - `prefix_sums[d] = sizes[0] + ... + sizes[d]`. + Mathematically: `prefix_sums[d] = sizes[0] + ... + sizes[d]`. """ def sizes_sum_in_range(self, start: int, end: int) -> int: - """Calculates the sum of sizes in an inclusive range [start, end].""" + """ + Calculates the sum of `sizes` in an inclusive range [start, end]. + + This is an O(1) operation thanks to the precomputed `prefix_sums`. + """ + # If the range is invalid, the sum is zero. if start > end: return 0 + # If the range starts from the beginning, the sum is simply the + # prefix sum at the end of the range. if start == 0: return self.prefix_sums[end] + # Otherwise, the sum is the difference between the prefix sum at the + # end and the prefix sum of the elements just before the start. else: return self.prefix_sums[end] - self.prefix_sums[start - 1] -@lru_cache(maxsize=None) -def prepare_layer_info(w: int) -> List[LayerInfo]: +def _calculate_layer_size(w: int, v: int, d: int) -> int: """ - Precomputes and caches the number of vertices in each layer of a hypercube. + Calculates a hypercube layer's size using a direct combinatorial formula. + + This function answers the question: "How many unique coordinate vectors + (vertices) exist in a specific layer `d`?" - It calculates the size of every layer for hypercubes with a given - base `w` (where coordinates are in `[0, w-1]`) for all dimensions - `v` up to `MAX_DIMENSION`. + The problem is mathematically equivalent to finding the number of integer + solutions to the equation: + x_1 + x_2 + ... + x_v = k + subject to the constraint that each coordinate `x_i` is in the range + `0 <= x_i <= w-1`. The required sum `k` is derived from the layer's + distance `d`. - This precomputation is based on the recurrence relation from - Lemma 8 of the paper "At the top of the hypercube" (eprint 2025/889). + The solution uses two key combinatorial techniques: + 1. **Stars and Bars**: To find the number of solutions without the + upper-bound constraint (`x_i <= w-1`). + + 2. **Inclusion-Exclusion Principle**: To correct the count by systematically + adding and subtracting the solutions that violate the upper bound. Args: - w: The base of the hypercube. + w: The hypercube base (coordinates are `0` to `w-1`). + v: The hypercube dimension (number of coordinates). + d: The target layer's distance from the sink vertex `(w-1, ..., w-1)`. Returns: - A list where the element at index `v` is a `LayerInfo` object - containing the layer sizes for a `v`-dimensional hypercube. + The total number of vertices in the specified layer. """ - # Initialize a list to store the results for each dimension `v`. + # A vertex is in layer `d` if its coordinates sum to `k = v * (w - 1) - d`. # - # Index 0 is unused to allow for direct indexing, e.g., `all_info[v]`. - all_info = [LayerInfo(sizes=[], prefix_sums=[])] * (MAX_DIMENSION + 1) + # This `coord_sum` is the `k` in our combinatorial problem. + coord_sum = v * (w - 1) - d - # BASE CASE - # - # For a 1-dimensional hypercube (v=1), which is just a line of `w` points - # with coordinates [0], [1], ..., [w-1]. + # This is the compact implementation of the inclusion-exclusion principle. # - # The distance `d` from the sink `[w-1]` is simply `(w-1) - coordinate`. + # It directly calculates the sum: Σ (-1)^s * C(v,s) * C(k - s*w + v-1, v-1) + return sum( + ((-1) ** s) * math.comb(v, s) * math.comb(coord_sum - s * w + v - 1, v - 1) + for s in range(coord_sum // w + 1) + ) + + +@lru_cache(maxsize=None) +def prepare_layer_info(w: int) -> List[LayerInfo]: + """ + Precomputes and caches layer information using a direct combinatorial formula. - # Each of the `w` possible layers contains exactly one vertex. - dim1_sizes = [1] * w - # The prefix sums (cumulative sizes) are therefore just [1, 2, 3, ..., w]. - dim1_prefix_sums = list(range(1, w + 1)) - # Store the result for v=1, which will seed the inductive step. - all_info[1] = LayerInfo(sizes=dim1_sizes, prefix_sums=dim1_prefix_sums) + For each dimension `v` up to `MAX_DIMENSION`, this function calculates the + size of every layer `d` directly, without relying on the results from + smaller dimensions. While less computationally efficient than the recursive + method, this implementation is more concise and mathematically direct. + + Args: + w: The base of the hypercube. + + Returns: + A list where `list[v]` is a `LayerInfo` object for a `v`-dim hypercube. + """ + all_info = [LayerInfo(sizes=[], prefix_sums=[])] * (MAX_DIMENSION + 1) - # Now, build the layer info for all higher dimensions up to the maximum. - for v in range(2, MAX_DIMENSION + 1): + for v in range(1, MAX_DIMENSION + 1): # The maximum possible distance `d` in a v-dimensional hypercube. max_d = (w - 1) * v - # Retrieve the already-computed data for the previous dimension (v-1). - prev_layer_info = all_info[v - 1] - - # This list will store the computed size of each layer `d` - # for dimension `v`. - current_sizes: List[int] = [] - for d in range(max_d + 1): - # Implements the recurrence l_d(v) = Σ l_{d-j}(v-1) from the paper. - # `j` is one coordinate's distance contribution. - - # Calculate the valid range [j_min, j_max] for `j`. - j_min = max(0, d - (w - 1) * (v - 1)) - j_max = min(w - 1, d) - - # Translate the sum over `j` to an index range `k`, - # where k = d - j. - # - # This allows for an efficient lookup using prefix sums. - k_min = d - j_max - k_max = d - j_min - - # Calculate the sum using the precomputed prefix sums - # from the previous dimension's `LayerInfo`. - layer_size = prev_layer_info.sizes_sum_in_range(k_min, k_max) - current_sizes.append(layer_size) - - # After computing all layer sizes for dimension `v`, we compute their - # prefix sums. - # - # This is needed for the *next* iteration (for dimension v+1). - current_prefix_sums: List[int] = [] - current_sum = 0 - for size in current_sizes: - current_sum += size - current_prefix_sums.append(current_sum) + + # Directly compute the size of each layer using the helper function. + sizes = [_calculate_layer_size(w, v, d) for d in range(max_d + 1)] + + # Compute the cumulative sums from the list of sizes. + prefix_sums = list(accumulate(sizes)) # Store the complete layer info for the current dimension `v`. - all_info[v] = LayerInfo(sizes=current_sizes, prefix_sums=current_prefix_sums) + all_info[v] = LayerInfo(sizes=sizes, prefix_sums=prefix_sums) - # Return the complete table of layer information for the given base `w`. return all_info @@ -152,16 +171,18 @@ def hypercube_part_size(w: int, v: int, d: int) -> int: def hypercube_find_layer(w: int, v: int, x: int) -> Tuple[int, int]: """ - Given a global index `x`, finds the layer `d` it belongs to and its - local index (`remainder`) within that layer. + Given a global index `x`, finds its layer `d` and local offset `remainder`. + + This function determines which "layer bucket" a global index falls into. Args: w: The hypercube base. v: The hypercube dimension. - x: The global index of a vertex (from 0 to w**v - 1). + x: The global index of a vertex, from 0 to (w**v - 1). Returns: - A tuple `(d, remainder)`. + A tuple `(d, remainder)`, where `d` is the layer and `remainder` is + the local index (offset) of the vertex within that layer. """ prefix_sums = prepare_layer_info(w)[v].prefix_sums # Use binary search to efficiently find the correct layer. @@ -187,17 +208,16 @@ def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]: """ Maps an integer index `x` to a unique vertex in a specific hypercube layer. - This function provides a bijective mapping from an integer `x` - (derived from a hash) to a unique list of `v` coordinates, - `[a_0, ..., a_{v-1}]`. - - The algorithm works iteratively, determining one coordinate at a time. + This function provides a bijective mapping from a location `(d, x)` to a + unique coordinate vector `[a_0, ..., a_{v-1}]`. The algorithm works + iteratively, determining one coordinate at a time by reducing the problem + to a smaller subproblem in a hypercube of one less dimension. Args: w: The hypercube base (coordinates are in `[0, w-1]`). v: The hypercube dimension (the number of coordinates). d: The target layer, defined by its distance from the sink vertex. - x: The integer index within the layer `d`, must be `0 <= x < size(d)`. + x: The integer index (offset) within layer `d`. Must be `0 <= x < size(d)`. Returns: A list of `v` integers representing the coordinates of the vertex. @@ -218,8 +238,7 @@ def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]: dim_remaining = v - i prev_dim_layer_info = layer_info_cache[dim_remaining] - # This loop finds which block of sub-hypercubes the index `x_curr` - # falls into. + # This loop finds which block of sub-hypercubes the index `x_curr` falls into. # # It skips over full blocks by subtracting their size # from `x_curr` until the correct one is found. @@ -230,7 +249,8 @@ def map_to_vertex(w: int, v: int, d: int, x: int) -> List[int]: if x_curr >= count: x_curr -= count else: - ji = j # Found the correct block. + # Found the correct block. + ji = j break if ji == -1: