diff --git a/src/modexp/Modexp.sol b/src/modexp/Modexp.sol index 334bac2..64dc4db 100644 --- a/src/modexp/Modexp.sol +++ b/src/modexp/Modexp.sol @@ -18,7 +18,7 @@ library Modexp { bytes memory base, bytes memory exponent, bytes memory modulus - ) internal view returns (bytes memory result) { + ) internal pure returns (bytes memory result) { if (modulus.length == 0) return new bytes(0); // Check if modulus is odd (last byte has bit 0 set) diff --git a/src/modexp/ModexpBarrett.sol b/src/modexp/ModexpBarrett.sol index e970559..7dbbbf1 100644 --- a/src/modexp/ModexpBarrett.sol +++ b/src/modexp/ModexpBarrett.sol @@ -12,7 +12,7 @@ library ModexpBarrett { bytes memory base, bytes memory exponent, bytes memory modulus - ) internal view returns (bytes memory result) { + ) internal pure returns (bytes memory result) { uint256 modLen = modulus.length; if (modLen == 0) return new bytes(0); @@ -46,10 +46,10 @@ library ModexpBarrett { uint256 k = (modLen + 31) / 32; uint256[] memory n = _bytesToLimbs(modulus, k); - uint256[] memory a = _reduceBase(base, modulus, k); + uint256[] memory a = _reduceBase(base, n, k); // Barrett constant: mu = floor(2^(512k) / n), has k+1 limbs - uint256[] memory mu = _computeBarrettConstant(n, k, modulus); + uint256[] memory mu = _computeBarrettConstant(n, k); // one = 1 as k-limb number (used as initial accumulator) uint256[] memory r = new uint256[](k); @@ -144,83 +144,31 @@ library ModexpBarrett { } } - // ── Precompile wrapper ──────────────────────────────────────────── - - function _callPrecompile(bytes memory b, bytes memory e, bytes memory m) - private view returns (bytes memory result) - { - uint256 modLen = m.length; - result = new bytes(modLen); - bytes memory input = abi.encodePacked( - uint256(b.length), uint256(e.length), uint256(modLen), b, e, m - ); - assembly { - let ok := staticcall(gas(), 0x05, add(input, 0x20), mload(input), add(result, 0x20), modLen) - if iszero(ok) { revert(0, 0) } - } - } - - function _reduceBase(bytes memory base, bytes memory modulus, uint256 k) - private view returns (uint256[] memory) + /// @dev Reduces base mod n via schoolbook division remainder. + function _reduceBase(bytes memory base, uint256[] memory n, uint256 k) + private pure returns (uint256[] memory) { - return _bytesToLimbs(_callPrecompile(base, hex"01", modulus), k); - } - - function _uint256ToMinBytes(uint256 val) private pure returns (bytes memory) { - if (val == 0) return hex"00"; - uint256 byteLen = 0; - uint256 tmp = val; - while (tmp > 0) { - byteLen++; - tmp >>= 8; - } - bytes memory result = new bytes(byteLen); - unchecked { - for (uint256 i = 0; i < byteLen; i++) { - result[byteLen - 1 - i] = bytes1(uint8(val)); - val >>= 8; - } - } - return result; + uint256 baseLen = base.length; + if (baseLen == 0) return new uint256[](k); + uint256 baseK = (baseLen + 31) / 32; + if (baseK < k) baseK = k; + uint256[] memory baseLimbs = _bytesToLimbs(base, baseK); + (, uint256[] memory rem) = _schoolbookDiv(baseLimbs, baseK, n, k); + return rem; } // ── Barrett constant computation ────────────────────────────────── /// @dev Computes mu = floor(2^(512k) / n). - function _computeBarrettConstant(uint256[] memory n, uint256 k, bytes memory modulus) - private view returns (uint256[] memory) + function _computeBarrettConstant(uint256[] memory n, uint256 k) + private pure returns (uint256[] memory) { - // r = 2^(512k) mod n via precompile - uint256 expVal = 512 * k; - bytes memory expBytes = _uint256ToMinBytes(expVal); - bytes memory base2 = hex"02"; - uint256[] memory r = _bytesToLimbs(_callPrecompile(base2, expBytes, modulus), k); - - // dividend = 2^(512k) - r, which is exactly divisible by n - // dividend has 2k+1 limbs: limbs 0..k-1 = two's complement of r, limb 2k = 1 + // Construct 2^(512k) as a (2k+1)-limb number: all zeros except limb[2k] = 1 uint256 dLen = 2 * k + 1; uint256[] memory dividend = new uint256[](dLen); - - // Compute -r mod 2^(256k), i.e., two's complement - assembly { - let divP := add(dividend, 0x20) - let rP := add(r, 0x20) - let borrow := 0 - for { let i := 0 } lt(i, k) { i := add(i, 1) } { - let ri := mload(add(rP, mul(i, 0x20))) - let d := sub(sub(0, ri), borrow) - borrow := or(gt(ri, 0), and(iszero(ri), borrow)) - mstore(add(divP, mul(i, 0x20)), d) - } - for { let i := k } lt(i, mul(2, k)) { i := add(i, 1) } { - let d := sub(0, borrow) - mstore(add(divP, mul(i, 0x20)), d) - } - mstore(add(divP, mul(mul(2, k), 0x20)), sub(1, borrow)) - } - - // mu = dividend / n via schoolbook division - return _schoolbookDiv(dividend, dLen, n, k); + dividend[2 * k] = 1; + (uint256[] memory mu,) = _schoolbookDiv(dividend, dLen, n, k); + return mu; } // ── Division helpers ────────────────────────────────────────────── @@ -277,7 +225,9 @@ library ModexpBarrett { uint256 dLen, uint256[] memory divisor, uint256 k - ) private pure returns (uint256[] memory quotient) { + ) private pure returns (uint256[] memory quotient, uint256[] memory rem) { + rem = new uint256[](k); + // Find actual length of dividend (strip leading zero limbs) uint256 m = dLen; while (m > 0 && dividend[m - 1] == 0) { @@ -285,7 +235,7 @@ library ModexpBarrett { } if (m == 0) { quotient = new uint256[](1); - return quotient; + return (quotient, rem); } // Find effective number of significant limbs in divisor @@ -305,13 +255,15 @@ library ModexpBarrett { (q, remainder) = _div512by256(remainder, dividend[i], d); quotient[i] = q; } - return quotient; + rem[0] = remainder; + return (quotient, rem); } // Multi-limb divisor: Knuth Algorithm D (using kEff significant limbs) if (m < kEff) { quotient = new uint256[](1); - return quotient; + for (uint256 i = 0; i < m; i++) rem[i] = dividend[i]; + return (quotient, rem); } uint256 numQlimbs = m - kEff + 1; quotient = new uint256[](numQlimbs); @@ -454,7 +406,19 @@ library ModexpBarrett { quotient[jj] = qHat; } - return quotient; + // Extract remainder from u[0..kEff-1] and denormalize (shift right) + if (shift > 0) { + for (uint256 i = 0; i < kEff; i++) { + rem[i] = u[i] >> shift; + if (i + 1 < kEff) { + rem[i] |= u[i + 1] << (256 - shift); + } + } + } else { + for (uint256 i = 0; i < kEff; i++) { + rem[i] = u[i]; + } + } } // ── Schoolbook multiplication ───────────────────────────────────── diff --git a/src/modexp/ModexpMontgomery.sol b/src/modexp/ModexpMontgomery.sol index 8f40292..4fe20d3 100644 --- a/src/modexp/ModexpMontgomery.sol +++ b/src/modexp/ModexpMontgomery.sol @@ -15,7 +15,7 @@ library ModexpMontgomery { bytes memory base, bytes memory exponent, bytes memory modulus - ) internal view returns (bytes memory result) { + ) internal pure returns (bytes memory result) { uint256 modLen = modulus.length; if (modLen == 0) return new bytes(0); @@ -26,11 +26,11 @@ library ModexpMontgomery { // Convert inputs to little-endian limb arrays uint256[] memory n = _bytesToLimbs(modulus, k); - uint256[] memory a = _reduceBase(base, modulus, k); + uint256[] memory a = _reduceBase(base, n, k); // Montgomery constants uint256 n0inv = _computeN0inv(n[0]); - uint256[] memory r2 = _computeR2ModN(k, modulus); + uint256[] memory r2 = _computeR2ModN(k, n); // one = 1 as a k-limb number uint256[] memory one = new uint256[](k); @@ -126,30 +126,18 @@ library ModexpMontgomery { } } - // ── Precompile wrapper ──────────────────────────────────────────── - - /// @dev Calls the EVM modexp precompile (address 0x05). - function _callPrecompile(bytes memory b, bytes memory e, bytes memory m) - private view returns (bytes memory result) - { - uint256 modLen = m.length; - result = new bytes(modLen); - bytes memory input = abi.encodePacked( - uint256(b.length), uint256(e.length), uint256(modLen), b, e, m - ); - assembly { - let ok := staticcall(gas(), 0x05, add(input, 0x20), mload(input), add(result, 0x20), modLen) - if iszero(ok) { revert(0, 0) } - } - } - // ── Montgomery setup ────────────────────────────────────────────── - /// @dev Reduces base mod n via the precompile (base^1 mod n) and returns limbs. - function _reduceBase(bytes memory base, bytes memory modulus, uint256 k) - private view returns (uint256[] memory) + /// @dev Reduces base mod n via schoolbook division remainder. + function _reduceBase(bytes memory base, uint256[] memory n, uint256 k) + private pure returns (uint256[] memory) { - return _bytesToLimbs(_callPrecompile(base, hex"01", modulus), k); + uint256 baseLen = base.length; + if (baseLen == 0) return new uint256[](k); + uint256 baseK = (baseLen + 31) / 32; + if (baseK < k) baseK = k; + uint256[] memory baseLimbs = _bytesToLimbs(base, baseK); + return _schoolbookRem(baseLimbs, baseK, n, k); } /// @dev Computes -n^{-1} mod 2^256 via Newton's method (8 doubling steps). @@ -163,34 +151,227 @@ library ModexpMontgomery { } } - /// @dev Computes R^2 mod n where R = 2^{256k}, via precompile: 2^{512k} mod n. - function _computeR2ModN(uint256 k, bytes memory modulus) - private view returns (uint256[] memory) + /// @dev Computes R^2 mod n where R = 2^{256k}. + function _computeR2ModN(uint256 k, uint256[] memory n) + private pure returns (uint256[] memory) { - uint256 expVal = 512 * k; - bytes memory expBytes = _uint256ToMinBytes(expVal); - bytes memory base2 = new bytes(1); - base2[0] = 0x02; - return _bytesToLimbs(_callPrecompile(base2, expBytes, modulus), k); + uint256 dLen = 2 * k + 1; + uint256[] memory dividend = new uint256[](dLen); + dividend[2 * k] = 1; + return _schoolbookRem(dividend, dLen, n, k); } - /// @dev Encodes a uint256 as minimal big-endian bytes (no leading zeros). - function _uint256ToMinBytes(uint256 val) private pure returns (bytes memory) { - if (val == 0) return hex"00"; - uint256 byteLen = 0; - uint256 tmp = val; - while (tmp > 0) { - byteLen++; - tmp >>= 8; + // ── Division helpers ────────────────────────────────────────────── + + /// @dev 512-by-256 division: (hi:lo) / d -> (quotient, remainder). + function _div512by256(uint256 hi, uint256 lo, uint256 d) + private pure returns (uint256 q, uint256 rem) + { + assembly { + if iszero(hi) { + q := div(lo, d) + rem := mod(lo, d) + } + if gt(hi, 0) { + let r256 := addmod(mod(not(0), d), 1, d) + rem := addmod(mulmod(hi, r256, d), lo, d) + + let lo_e := sub(lo, rem) + let hi_e := sub(hi, lt(lo, rem)) + + let twos := and(d, sub(0, d)) + d := div(d, twos) + + lo_e := div(lo_e, twos) + let flip := add(div(sub(0, twos), twos), 1) + lo_e := or(lo_e, mul(hi_e, flip)) + + let inv := xor(mul(3, d), 2) + inv := mul(inv, sub(2, mul(d, inv))) + inv := mul(inv, sub(2, mul(d, inv))) + inv := mul(inv, sub(2, mul(d, inv))) + inv := mul(inv, sub(2, mul(d, inv))) + inv := mul(inv, sub(2, mul(d, inv))) + inv := mul(inv, sub(2, mul(d, inv))) + + q := mul(lo_e, inv) + } } - bytes memory result = new bytes(byteLen); - unchecked { - for (uint256 i = 0; i < byteLen; i++) { - result[byteLen - 1 - i] = bytes1(uint8(val)); - val >>= 8; + } + + /// @dev Schoolbook long division returning remainder only (Knuth Algorithm D). + function _schoolbookRem( + uint256[] memory dividend, uint256 dLen, + uint256[] memory divisor, uint256 k + ) private pure returns (uint256[] memory remainder) { + remainder = new uint256[](k); + + uint256 m = dLen; + while (m > 0 && dividend[m - 1] == 0) m--; + if (m == 0) return remainder; + + uint256 kEff = k; + while (kEff > 1 && divisor[kEff - 1] == 0) kEff--; + + // Single-limb divisor + if (kEff == 1) { + uint256 d = divisor[0]; + uint256 rem = 0; + for (uint256 i = m; i > 0;) { + unchecked { i--; } + (, rem) = _div512by256(rem, dividend[i], d); + } + remainder[0] = rem; + return remainder; + } + + // Dividend shorter than divisor: dividend IS the remainder + if (m < kEff) { + for (uint256 i = 0; i < m; i++) remainder[i] = dividend[i]; + return remainder; + } + + uint256 numQlimbs = m - kEff + 1; + + uint256[] memory u = new uint256[](m + 1); + assembly { mcopy(add(u, 0x20), add(dividend, 0x20), mul(m, 0x20)) } + + // Normalize: shift divisor so top limb has high bit set + uint256 topD = divisor[kEff - 1]; + uint256 shift = 0; + { + uint256 tmp = topD; + while (tmp < (1 << 255)) { tmp <<= 1; shift++; } + } + + uint256[] memory v = new uint256[](kEff); + if (shift > 0) { + uint256 carry = 0; + for (uint256 i = 0; i < kEff; i++) { + uint256 newVal = (divisor[i] << shift) | carry; + carry = divisor[i] >> (256 - shift); + v[i] = newVal; + } + carry = 0; + for (uint256 i = 0; i < m; i++) { + uint256 newVal = (u[i] << shift) | carry; + carry = u[i] >> (256 - shift); + u[i] = newVal; + } + u[m] = carry; + } else { + assembly { mcopy(add(v, 0x20), add(divisor, 0x20), mul(kEff, 0x20)) } + } + + uint256 vTop = v[kEff - 1]; + + for (uint256 jj = numQlimbs; jj > 0;) { + unchecked { jj--; } + uint256 uHi = u[jj + kEff]; + uint256 uLo = u[jj + kEff - 1]; + + uint256 qHat; + { + uint256 rHat; + bool doRefinement; + if (uHi >= vTop) { + qHat = type(uint256).max; + rHat = uLo + vTop; + doRefinement = (rHat >= uLo); + } else { + (qHat, rHat) = _div512by256(uHi, uLo, vTop); + doRefinement = true; + } + + if (doRefinement && kEff >= 2) { + uint256 vSecond = v[kEff - 2]; + uint256 uSecond = u[jj + kEff - 2]; + assembly { + let qvLo := mul(qHat, vSecond) + let qvMM := mulmod(qHat, vSecond, not(0)) + let qvHi := sub(sub(qvMM, qvLo), lt(qvMM, qvLo)) + + for {} or(gt(qvHi, rHat), and(eq(qvHi, rHat), gt(qvLo, uSecond))) {} { + qHat := sub(qHat, 1) + rHat := add(rHat, vTop) + if lt(rHat, vTop) { break } + qvLo := mul(qHat, vSecond) + qvMM := mulmod(qHat, vSecond, not(0)) + qvHi := sub(sub(qvMM, qvLo), lt(qvMM, qvLo)) + } + } + } + } + bool negative; + assembly { + let uP := add(u, 0x20) + let vP := add(v, 0x20) + let carry := 0 + let borrow := 0 + + for { let i := 0 } lt(i, kEff) { i := add(i, 1) } { + let vi := mload(add(vP, mul(i, 0x20))) + let pLo := mul(qHat, vi) + let pMM := mulmod(qHat, vi, not(0)) + let pH := sub(sub(pMM, pLo), lt(pMM, pLo)) + + let withCarry := add(pLo, carry) + let newCarry := add(pH, lt(withCarry, pLo)) + carry := newCarry + + let uOff := add(uP, mul(add(jj, i), 0x20)) + let uVal := mload(uOff) + let diff := sub(uVal, withCarry) + let newBorrow := lt(uVal, withCarry) + let diff2 := sub(diff, borrow) + newBorrow := or(newBorrow, lt(diff, borrow)) + borrow := newBorrow + mstore(uOff, diff2) + } + + let uTopOff := add(uP, mul(add(jj, kEff), 0x20)) + let uTopVal := mload(uTopOff) + let diff := sub(uTopVal, carry) + let nb := lt(uTopVal, carry) + let diff2 := sub(diff, borrow) + nb := or(nb, lt(diff, borrow)) + mstore(uTopOff, diff2) + negative := nb + } + + if (negative) { + assembly { + let uP := add(u, 0x20) + let vP := add(v, 0x20) + let carry := 0 + for { let i := 0 } lt(i, kEff) { i := add(i, 1) } { + let uOff := add(uP, mul(add(jj, i), 0x20)) + let s := add(mload(uOff), mload(add(vP, mul(i, 0x20)))) + let c1 := lt(s, mload(uOff)) + let s2 := add(s, carry) + let c2 := lt(s2, s) + mstore(uOff, s2) + carry := or(c1, c2) + } + let uTopOff := add(uP, mul(add(jj, kEff), 0x20)) + mstore(uTopOff, add(mload(uTopOff), carry)) + } + } + } + + // Extract remainder from u[0..kEff-1] and denormalize (shift right) + if (shift > 0) { + for (uint256 i = 0; i < kEff; i++) { + remainder[i] = u[i] >> shift; + if (i + 1 < kEff) { + remainder[i] |= u[i + 1] << (256 - shift); + } + } + } else { + for (uint256 i = 0; i < kEff; i++) { + remainder[i] = u[i]; } } - return result; } // ── Square-and-multiply ───────────────────────────────────────────