Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 23 additions & 13 deletions src/lean_spec/subspecs/poseidon2/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,6 @@ def check_lengths(self) -> "Poseidon2Params":
[Fp(value=3), Fp(value=1), Fp(value=1), Fp(value=2)],
]

# =================================================================
# Linear Layers
# =================================================================


def _apply_m4(chunk: List[Fp]) -> List[Fp]:
"""
Expand Down Expand Up @@ -213,17 +209,31 @@ def internal_linear_layer(state: List[Fp], params: Poseidon2Params) -> List[Fp]:
Returns:
The state vector after applying the internal linear layer.
"""
# Calculate the sum of all elements in the state vector.
s_sum = sum(state, Fp(value=0))
# For each element s_i, compute s_i' = d_i * s_i + sum(s).
# This is the efficient computation of (J + D)s.
new_state = [s * d + s_sum for s, d in zip(state, params.internal_diag_vectors, strict=False)]
return new_state
width = params.width
diag_vector = params.internal_diag_vectors

# Construct the M_I matrix explicitly.
#
# It has dimensions width x width.
m_i = [[Fp(value=1) for _ in range(width)] for _ in range(width)]

# =================================================================
# Core Permutation
# =================================================================
# Add the diagonal part (D) to the all-ones matrix (J)
#
# The result is M_I = J + D.
for i in range(width):
m_i[i][i] += diag_vector[i]

# Perform standard matrix-vector multiplication: new_state = m_i * state
#
# Initialize the result vector with zeros.
new_state = [Fp(value=0)] * width

# For each row in the matrix, calculate the dot product of that row with the state vector.
for i in range(width):
for j in range(width):
new_state[i] += m_i[i][j] * state[j]

return new_state


def permute(state: List[Fp], params: Poseidon2Params) -> List[Fp]:
Expand Down
Loading