Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/modexp/Modexp.sol
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ library Modexp {
bytes memory base,
bytes memory exponent,
bytes memory modulus
) internal view returns (bytes memory result) {
) internal pure returns (bytes memory result) {
if (modulus.length == 0) return new bytes(0);

// Check if modulus is odd (last byte has bit 0 set)
Expand Down
116 changes: 40 additions & 76 deletions src/modexp/ModexpBarrett.sol
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ library ModexpBarrett {
bytes memory base,
bytes memory exponent,
bytes memory modulus
) internal view returns (bytes memory result) {
) internal pure returns (bytes memory result) {
uint256 modLen = modulus.length;
if (modLen == 0) return new bytes(0);

Expand Down Expand Up @@ -46,10 +46,10 @@ library ModexpBarrett {
uint256 k = (modLen + 31) / 32;

uint256[] memory n = _bytesToLimbs(modulus, k);
uint256[] memory a = _reduceBase(base, modulus, k);
uint256[] memory a = _reduceBase(base, n, k);

// Barrett constant: mu = floor(2^(512k) / n), has k+1 limbs
uint256[] memory mu = _computeBarrettConstant(n, k, modulus);
uint256[] memory mu = _computeBarrettConstant(n, k);

// one = 1 as k-limb number (used as initial accumulator)
uint256[] memory r = new uint256[](k);
Expand Down Expand Up @@ -144,83 +144,31 @@ library ModexpBarrett {
}
}

// ── Precompile wrapper ────────────────────────────────────────────

function _callPrecompile(bytes memory b, bytes memory e, bytes memory m)
private view returns (bytes memory result)
{
uint256 modLen = m.length;
result = new bytes(modLen);
bytes memory input = abi.encodePacked(
uint256(b.length), uint256(e.length), uint256(modLen), b, e, m
);
assembly {
let ok := staticcall(gas(), 0x05, add(input, 0x20), mload(input), add(result, 0x20), modLen)
if iszero(ok) { revert(0, 0) }
}
}

function _reduceBase(bytes memory base, bytes memory modulus, uint256 k)
private view returns (uint256[] memory)
/// @dev Reduces base mod n via schoolbook division remainder.
function _reduceBase(bytes memory base, uint256[] memory n, uint256 k)
private pure returns (uint256[] memory)
{
return _bytesToLimbs(_callPrecompile(base, hex"01", modulus), k);
}

function _uint256ToMinBytes(uint256 val) private pure returns (bytes memory) {
if (val == 0) return hex"00";
uint256 byteLen = 0;
uint256 tmp = val;
while (tmp > 0) {
byteLen++;
tmp >>= 8;
}
bytes memory result = new bytes(byteLen);
unchecked {
for (uint256 i = 0; i < byteLen; i++) {
result[byteLen - 1 - i] = bytes1(uint8(val));
val >>= 8;
}
}
return result;
uint256 baseLen = base.length;
if (baseLen == 0) return new uint256[](k);
uint256 baseK = (baseLen + 31) / 32;
if (baseK < k) baseK = k;
uint256[] memory baseLimbs = _bytesToLimbs(base, baseK);
(, uint256[] memory rem) = _schoolbookDiv(baseLimbs, baseK, n, k);
return rem;
}

// ── Barrett constant computation ──────────────────────────────────

/// @dev Computes mu = floor(2^(512k) / n).
function _computeBarrettConstant(uint256[] memory n, uint256 k, bytes memory modulus)
private view returns (uint256[] memory)
function _computeBarrettConstant(uint256[] memory n, uint256 k)
private pure returns (uint256[] memory)
{
// r = 2^(512k) mod n via precompile
uint256 expVal = 512 * k;
bytes memory expBytes = _uint256ToMinBytes(expVal);
bytes memory base2 = hex"02";
uint256[] memory r = _bytesToLimbs(_callPrecompile(base2, expBytes, modulus), k);

// dividend = 2^(512k) - r, which is exactly divisible by n
// dividend has 2k+1 limbs: limbs 0..k-1 = two's complement of r, limb 2k = 1
// Construct 2^(512k) as a (2k+1)-limb number: all zeros except limb[2k] = 1
uint256 dLen = 2 * k + 1;
uint256[] memory dividend = new uint256[](dLen);

// Compute -r mod 2^(256k), i.e., two's complement
assembly {
let divP := add(dividend, 0x20)
let rP := add(r, 0x20)
let borrow := 0
for { let i := 0 } lt(i, k) { i := add(i, 1) } {
let ri := mload(add(rP, mul(i, 0x20)))
let d := sub(sub(0, ri), borrow)
borrow := or(gt(ri, 0), and(iszero(ri), borrow))
mstore(add(divP, mul(i, 0x20)), d)
}
for { let i := k } lt(i, mul(2, k)) { i := add(i, 1) } {
let d := sub(0, borrow)
mstore(add(divP, mul(i, 0x20)), d)
}
mstore(add(divP, mul(mul(2, k), 0x20)), sub(1, borrow))
}

// mu = dividend / n via schoolbook division
return _schoolbookDiv(dividend, dLen, n, k);
dividend[2 * k] = 1;
(uint256[] memory mu,) = _schoolbookDiv(dividend, dLen, n, k);
return mu;
}

// ── Division helpers ──────────────────────────────────────────────
Expand Down Expand Up @@ -277,15 +225,17 @@ library ModexpBarrett {
uint256 dLen,
uint256[] memory divisor,
uint256 k
) private pure returns (uint256[] memory quotient) {
) private pure returns (uint256[] memory quotient, uint256[] memory rem) {
rem = new uint256[](k);

// Find actual length of dividend (strip leading zero limbs)
uint256 m = dLen;
while (m > 0 && dividend[m - 1] == 0) {
m--;
}
if (m == 0) {
quotient = new uint256[](1);
return quotient;
return (quotient, rem);
}

// Find effective number of significant limbs in divisor
Expand All @@ -305,13 +255,15 @@ library ModexpBarrett {
(q, remainder) = _div512by256(remainder, dividend[i], d);
quotient[i] = q;
}
return quotient;
rem[0] = remainder;
return (quotient, rem);
}

// Multi-limb divisor: Knuth Algorithm D (using kEff significant limbs)
if (m < kEff) {
quotient = new uint256[](1);
return quotient;
for (uint256 i = 0; i < m; i++) rem[i] = dividend[i];
return (quotient, rem);
}
uint256 numQlimbs = m - kEff + 1;
quotient = new uint256[](numQlimbs);
Expand Down Expand Up @@ -454,7 +406,19 @@ library ModexpBarrett {
quotient[jj] = qHat;
}

return quotient;
// Extract remainder from u[0..kEff-1] and denormalize (shift right)
if (shift > 0) {
for (uint256 i = 0; i < kEff; i++) {
rem[i] = u[i] >> shift;
if (i + 1 < kEff) {
rem[i] |= u[i + 1] << (256 - shift);
}
}
} else {
for (uint256 i = 0; i < kEff; i++) {
rem[i] = u[i];
}
}
}

// ── Schoolbook multiplication ─────────────────────────────────────
Expand Down
Loading