diff --git a/Cargo.lock b/Cargo.lock index 6e32e307..f7c0c812 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1400,6 +1400,7 @@ dependencies = [ "ff 0.13.1", "generic-array", "group 0.13.0", + "hkdf", "pem-rfc7468", "pkcs8", "rand_core 0.6.4", @@ -1941,6 +1942,15 @@ version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6fe2267d4ed49bc07b63801559be28c718ea06c4738b7a03c94df7386d2cde46" +[[package]] +name = "hkdf" +version = "0.12.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5f8eb2ad728638ea2c7d47a21db23b7b58a72ed6a38256b8a1849f15fbbdf7" +dependencies = [ + "hmac", +] + [[package]] name = "hmac" version = "0.12.1" @@ -3547,6 +3557,7 @@ name = "pico-patch-libs" version = "1.2.2" dependencies = [ "bincode", + "elliptic-curve", "serde", ] diff --git a/Cargo.toml b/Cargo.toml index dbaefd04..8fe2a730 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -112,7 +112,7 @@ dashmap = "6.1.0" dashu = "0.4.2" derive_more = { version = "2.0", features = ["constructor"] } elf = "0.7.4" -elliptic-curve = "0.13.8" +elliptic-curve = { version = "0.13.8", features = ["ecdh", "hazmat", "sec1"] } env_logger = "0.11.6" eyre = "0.6.12" ff = { version = "0.13", features = ["derive", "derive_bits"] } diff --git a/sdk/cli/src/subcommand/config.toml b/sdk/cli/src/subcommand/config.toml index cfcd918d..7a285768 100644 --- a/sdk/cli/src/subcommand/config.toml +++ b/sdk/cli/src/subcommand/config.toml @@ -1,5 +1,5 @@ [build] -target = ["riscv32im-risc0-zkvm-elf"] +target = ["riscv64im-pico-zkvm-elf"] extended = true tools = ["cargo", "cargo-clippy", "clippy", "rustfmt"] configure-args = [] diff --git a/sdk/cli/src/subcommand/prove.rs b/sdk/cli/src/subcommand/prove.rs index a6a6f384..965863ca 100644 --- a/sdk/cli/src/subcommand/prove.rs +++ b/sdk/cli/src/subcommand/prove.rs @@ -94,7 +94,7 @@ impl ProveCmd { .parent() .unwrap() .join(DEFAULT_ELF_DIR) - .join("riscv32im-pico-zkvm-elf") + .join("riscv64im-pico-zkvm-elf") } }; let 
elf: Vec = std::fs::read(elf_path)?; diff --git a/sdk/patch-libs/Cargo.toml b/sdk/patch-libs/Cargo.toml index 7de81c5b..7c82424b 100644 --- a/sdk/patch-libs/Cargo.toml +++ b/sdk/patch-libs/Cargo.toml @@ -9,4 +9,5 @@ categories = { workspace = true } [dependencies] bincode.workspace = true +elliptic-curve.workspace = true serde.workspace = true diff --git a/sdk/patch-libs/src/bls12381.rs b/sdk/patch-libs/src/bls12381.rs index 4938cee5..6e3e0144 100644 --- a/sdk/patch-libs/src/bls12381.rs +++ b/sdk/patch-libs/src/bls12381.rs @@ -6,11 +6,11 @@ use crate::{ }; /// The number of limbs in [Bls12381AffinePoint]. -pub const N: usize = 24; +pub const N: usize = 12; /// A point on the BLS12-381 curve. #[derive(Copy, Clone)] -#[repr(align(4))] +#[repr(align(8))] pub struct Bls12381Point(pub WeierstrassPoint); impl WeierstrassAffinePoint for Bls12381Point { @@ -24,38 +24,53 @@ impl WeierstrassAffinePoint for Bls12381Point { } impl AffinePoint for Bls12381Point { + const GENERATOR: [u64; N] = [ + 18103045581585958587, + 7806400890582735599, + 11623291730934869080, + 14080658508445169925, + 2780237799254240271, + 1725392847304644500, + 912580534683953121, + 15005087156090211044, + 61670280795567085, + 18227722000993880822, + 11573741888802228964, + 627113611842199793, + ]; + /// The generator was taken from "py_ecc" python library by the Ethereum Foundation: /// /// https://github.com/ethereum/py_ecc/blob/7b9e1b3/py_ecc/bls12_381/bls12_381_curve.py#L38-L45 - const GENERATOR: [u32; N] = [ - 3676489403, 4214943754, 4185529071, 1817569343, 387689560, 2706258495, 2541009157, - 3278408783, 1336519695, 647324556, 832034708, 401724327, 1187375073, 212476713, 2726857444, - 3493644100, 738505709, 14358731, 3587181302, 4243972245, 1948093156, 2694721773, - 3819610353, 146011265, - ]; + #[allow(deprecated)] + const GENERATOR_T: Self = Self(WeierstrassPoint::Affine(Self::GENERATOR)); - fn new(limbs: [u32; N]) -> Self { + fn new(limbs: [u64; N]) -> Self { 
Self(WeierstrassPoint::Affine(limbs)) } - fn limbs_ref(&self) -> &[u32; N] { + fn identity() -> Self { + Self::infinity() + } + + fn is_identity(&self) -> bool { + self.is_infinity() + } + + fn limbs_ref(&self) -> &[u64; N] { match &self.0 { WeierstrassPoint::Infinity => panic!("Infinity point has no limbs"), WeierstrassPoint::Affine(limbs) => limbs, } } - fn limbs_mut(&mut self) -> &mut [u32; N] { + fn limbs_mut(&mut self) -> &mut [u64; N] { match &mut self.0 { WeierstrassPoint::Infinity => panic!("Infinity point has no limbs"), WeierstrassPoint::Affine(limbs) => limbs, } } - fn complete_add_assign(&mut self, other: &Self) { - self.weierstrass_add_assign(other); - } - fn add_assign(&mut self, other: &Self) { let a = self.limbs_mut(); let b = other.limbs_ref(); @@ -64,6 +79,10 @@ impl AffinePoint for Bls12381Point { } } + fn complete_add_assign(&mut self, other: &Self) { + self.weierstrass_add_assign(other); + } + fn double(&mut self) { let a = self.limbs_mut(); unsafe { @@ -73,12 +92,22 @@ impl AffinePoint for Bls12381Point { } /// Decompresses a compressed public key using bls12381_decompress precompile. -pub fn decompress_pubkey(compressed_key: &[u8; 48]) -> Result<[u8; 96], ErrorKind> { - let mut decompressed_key = [0u8; 96]; - decompressed_key[..48].copy_from_slice(compressed_key); +pub fn decompress_pubkey(compressed_key: &[u64; 6]) -> Result<[u64; 12], ErrorKind> { + let mut decompressed_key = [0u64; 12]; + decompressed_key[..6].copy_from_slice(compressed_key); + + // The sign bit is stored in the first byte, so we have to access it like this. + let mut decompressed_key = decompressed_key.map(u64::to_ne_bytes); + + // The sign bit is the third most significant bit (beginning the count at "first"). 
+ const SIGN_OFFSET: u32 = 3; + const SIGN_MASK: u8 = 1u8 << (u8::BITS - SIGN_OFFSET); + let sign_bit = (decompressed_key[0][0] & SIGN_MASK) != 0; + decompressed_key[0][0] <<= SIGN_OFFSET; + decompressed_key[0][0] >>= SIGN_OFFSET; + + let mut decompressed_key = decompressed_key.map(u64::from_ne_bytes); - let sign_bit = ((decompressed_key[0] & 0b_0010_0000) >> 5) == 1; - decompressed_key[0] &= 0b_0001_1111; unsafe { syscall_bls12381_decompress(&mut decompressed_key, sign_bit); } diff --git a/sdk/patch-libs/src/bn254.rs b/sdk/patch-libs/src/bn254.rs index a541bc86..ae7cf464 100644 --- a/sdk/patch-libs/src/bn254.rs +++ b/sdk/patch-libs/src/bn254.rs @@ -4,11 +4,11 @@ use crate::{ }; /// The number of limbs in [Bn254AffinePoint]. -pub const N: usize = 16; +pub const N: usize = 8; /// A point on the Bn254 curve. #[derive(Copy, Clone)] -#[repr(align(4))] +#[repr(align(8))] pub struct Bn254Point(pub WeierstrassPoint); impl WeierstrassAffinePoint for Bn254Point { @@ -22,33 +22,40 @@ impl WeierstrassAffinePoint for Bn254Point { } impl AffinePoint for Bn254Point { + const GENERATOR: [u64; N] = [1, 0, 0, 0, 2, 0, 0, 0]; + + #[allow(deprecated)] /// The generator has been taken from py_pairing python library by the Ethereum Foundation: /// /// https://github.com/ethereum/py_pairing/blob/5f609da/py_ecc/bn128/bn128_field_elements.py - const GENERATOR: [u32; N] = [1, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0]; + const GENERATOR_T: Self = Self(WeierstrassPoint::Affine(Self::GENERATOR)); - fn new(limbs: [u32; N]) -> Self { + fn new(limbs: [u64; N]) -> Self { Self(WeierstrassPoint::Affine(limbs)) } - fn limbs_ref(&self) -> &[u32; N] { + fn identity() -> Self { + Self::infinity() + } + + fn is_identity(&self) -> bool { + self.is_infinity() + } + + fn limbs_ref(&self) -> &[u64; N] { match &self.0 { WeierstrassPoint::Infinity => panic!("Infinity point has no limbs"), WeierstrassPoint::Affine(limbs) => limbs, } } - fn limbs_mut(&mut self) -> &mut [u32; N] { + fn limbs_mut(&mut self) -> 
&mut [u64; N] { match &mut self.0 { WeierstrassPoint::Infinity => panic!("Infinity point has no limbs"), WeierstrassPoint::Affine(limbs) => limbs, } } - fn complete_add_assign(&mut self, other: &Self) { - self.weierstrass_add_assign(other); - } - fn add_assign(&mut self, other: &Self) { let a = self.limbs_mut(); let b = other.limbs_ref(); @@ -57,6 +64,10 @@ impl AffinePoint for Bn254Point { } } + fn complete_add_assign(&mut self, other: &Self) { + self.weierstrass_add_assign(other); + } + fn double(&mut self) { let a = self.limbs_mut(); unsafe { diff --git a/sdk/patch-libs/src/ecdsa.rs b/sdk/patch-libs/src/ecdsa.rs new file mode 100644 index 00000000..1ea43b4b --- /dev/null +++ b/sdk/patch-libs/src/ecdsa.rs @@ -0,0 +1,111 @@ +//! An implementation of the types needed for [`CurveArithmetic`]. +//! +//! [`CurveArithmetic`] is a trait that is used in [RustCryptos ECDSA](https://github.com/RustCrypto/signatures). +//! +//! [`CurveArithmetic`] contains all the types needed to implement the ECDSA algorithm over some +//! curve. +//! +//! This implementation is specifically for use inside of Pico zkVM, and internally uses Pico's Weierstrass +//! precompiles. +//! +//! In summary, Pico overrides curve arithmetic entirely, and patches upstream field operations +//! to be more efficient in the VM, such as `sqrt` or `inverse`. + +use crate::utils::AffinePoint as PicoAffinePointTrait; + +use elliptic_curve::{ + ff, generic_array::typenum::consts::U32, subtle::CtOption, CurveArithmetic, FieldBytes, +}; +use std::{fmt::Debug, ops::Neg}; + +/// The affine point type for Pico. +pub mod affine; +pub use affine::AffinePoint; + +/// The projective point type for Pico. +pub mod projective; +pub use projective::ProjectivePoint; + +/// NOTE: The only supported ECDSA curves are secp256k1 and secp256r1, which both +/// have 8 limbs in their field elements. +const POINT_LIMBS: usize = 4 * 2; + +/// The number of bytes in a field element as an [`usize`]. 
+const FIELD_BYTES_SIZE_USIZE: usize = 32; + +/// The number of bytes in a field element as an [`elliptic_curve::generic_array::U32`]. +#[allow(non_camel_case_types)] +type FIELD_BYTES_SIZE = U32; + +/// A [`CurveArithmetic`] extension for Pico acceleration. +/// +/// Patched crates implement this trait to take advantage of Pico-specific acceleration in the zkVM +/// context. +/// +/// Note: This trait only supports 32 byte base field curves. +pub trait ECDSACurve +where + Self: CurveArithmetic< + FieldBytesSize = FIELD_BYTES_SIZE, + AffinePoint = AffinePoint, + ProjectivePoint = ProjectivePoint, + >, +{ + type FieldElement: Field + Neg; + + /// The underlying [`PicoAffinePointTrait`] implementation. + type PicoAffinePoint: ECDSAPoint; + + /// The `a` coefficient in the curve equation. + const EQUATION_A: Self::FieldElement; + + /// The `b` coefficient in the curve equation. + const EQUATION_B: Self::FieldElement; +} + +/// Alias trait for the [`ff::PrimeField`] with 32 byte field elements. +/// +/// Note: All bytes should be considered to be in big-endian format. +pub trait Field: ff::PrimeField { + /// Create an instance of self from a FieldBytes. + fn from_bytes(bytes: &FieldBytes) -> CtOption; + + /// Convert self to a FieldBytes. + /// + /// Note: Implementers should ensure these methods normalize first. + fn to_bytes(self) -> FieldBytes; + + /// Ensure the field element is normalized. + fn normalize(self) -> Self; +} + +pub type FieldElement = ::FieldElement; + +/// Alias trait for the [`PicoAffinePointTrait`] with 32 byte field elements. +pub trait ECDSAPoint: + PicoAffinePointTrait + Clone + Copy + Debug + Send + Sync +{ + #[inline] + fn from(x: &[u8], y: &[u8]) -> Self { + >::from(x, y) + } +} + +impl

ECDSAPoint for P where + P: PicoAffinePointTrait + Clone + Copy + Debug + Send + Sync +{ +} + +pub mod ecdh { + pub use elliptic_curve::ecdh::{diffie_hellman, EphemeralSecret, SharedSecret}; + + use super::{AffinePoint, ECDSACurve, Field}; + + impl From<&AffinePoint> for SharedSecret { + fn from(affine: &AffinePoint) -> SharedSecret { + let (x, _) = affine.field_elements(); + + x.to_bytes().into() + } + } +} diff --git a/sdk/patch-libs/src/ecdsa/affine.rs b/sdk/patch-libs/src/ecdsa/affine.rs new file mode 100644 index 00000000..a038ea95 --- /dev/null +++ b/sdk/patch-libs/src/ecdsa/affine.rs @@ -0,0 +1,248 @@ +//! Implementation of an affine point, with acceleration for operations in the context of Pico. +//! +//! The [`crate::ecdsa::ProjectivePoint`] type is mainly used in the `ecdsa-core` algorithms, +//! however, in some cases, the affine point is required. +//! +//! Note: When performing curve operations, accelerated crates for Pico use affine arithmetic instead +//! of projective arithmetic for performance. + +#![allow(deprecated)] +use super::{ + ECDSACurve, ECDSAPoint, Field, FieldElement, PicoAffinePointTrait, FIELD_BYTES_SIZE_USIZE, +}; + +use elliptic_curve::{ + ff::Field as _, + group::GroupEncoding, + point::{AffineCoordinates, DecompactPoint, DecompressPoint}, + sec1::{self, CompressedPoint, EncodedPoint, FromEncodedPoint, ToEncodedPoint}, + subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}, + zeroize::DefaultIsZeroes, + FieldBytes, PrimeField, +}; +use std::ops::Neg; + +#[derive(Clone, Copy, Debug)] +pub struct AffinePoint { + pub inner: C::PicoAffinePoint, +} + +impl AffinePoint { + /// Create an affine point from the given field elements, without checking if the point is on + /// the curve. 
+ pub fn from_field_elements_unchecked(x: FieldElement, y: FieldElement) -> Self { + let mut x_slice = x.to_bytes(); + let x_slice = x_slice.as_mut_slice(); + x_slice.reverse(); + + let mut y_slice = y.to_bytes(); + let y_slice = y_slice.as_mut_slice(); + y_slice.reverse(); + + AffinePoint { + inner: ::from(x_slice, y_slice), + } + } + + /// Get the x and y field elements of the point. + /// + /// The returned elements are always normalized. + pub fn field_elements(&self) -> (FieldElement, FieldElement) { + if self.is_identity().into() { + return (FieldElement::::ZERO, FieldElement::::ZERO); + } + + let bytes = self.inner.to_le_bytes(); + + let mut x_bytes: [u8; FIELD_BYTES_SIZE_USIZE] = + bytes[..FIELD_BYTES_SIZE_USIZE].try_into().unwrap(); + + x_bytes.reverse(); + + let mut y_bytes: [u8; FIELD_BYTES_SIZE_USIZE] = + bytes[FIELD_BYTES_SIZE_USIZE..].try_into().unwrap(); + + y_bytes.reverse(); + + let x = FieldElement::::from_bytes(&x_bytes.into()).unwrap(); + let y = FieldElement::::from_bytes(&y_bytes.into()).unwrap(); + (x, y) + } + + /// Get the generator point. + pub fn generator() -> Self { + AffinePoint { + inner: C::PicoAffinePoint::GENERATOR_T, + } + } + + /// Get the identity point. + pub fn identity() -> Self { + AffinePoint { + inner: C::PicoAffinePoint::identity(), + } + } + + /// Check if the point is the identity point. 
+ pub fn is_identity(&self) -> Choice { + Choice::from(self.inner.is_identity() as u8) + } +} + +impl FromEncodedPoint for AffinePoint { + fn from_encoded_point(point: &EncodedPoint) -> CtOption { + match point.coordinates() { + sec1::Coordinates::Identity => CtOption::new(Self::identity(), 1.into()), + sec1::Coordinates::Compact { x } => Self::decompact(x), + sec1::Coordinates::Compressed { x, y_is_odd } => { + AffinePoint::::decompress(x, Choice::from(y_is_odd as u8)) + } + sec1::Coordinates::Uncompressed { x, y } => { + let x = FieldElement::::from_bytes(x); + let y = FieldElement::::from_bytes(y); + + x.and_then(|x| { + y.and_then(|y| { + // Ensure the point is on the curve. + let lhs = (y * y).normalize(); + let rhs = (x * x * x) + (C::EQUATION_A * x) + C::EQUATION_B; + + let point = Self::from_field_elements_unchecked(x, y); + + CtOption::new(point, lhs.ct_eq(&rhs.normalize())) + }) + }) + } + } + } +} + +impl ToEncodedPoint for AffinePoint { + fn to_encoded_point(&self, compress: bool) -> EncodedPoint { + // If the point is the identity point, just return the identity point. + if self.is_identity().into() { + return EncodedPoint::::identity(); + } + + let (x, y) = self.field_elements(); + + // The field elements are already normalized by virtue of being created via `FromBytes`. + EncodedPoint::::from_affine_coordinates(&x.to_bytes(), &y.to_bytes(), compress) + } +} + +impl DecompressPoint for AffinePoint { + fn decompress(x_bytes: &FieldBytes, y_is_odd: Choice) -> CtOption { + FieldElement::::from_bytes(x_bytes).and_then(|x| { + let alpha = (x * x * x) + (C::EQUATION_A * x) + C::EQUATION_B; + let beta = alpha.sqrt(); + + beta.map(|beta| { + // Ensure the element is normalized for consistency. + let beta = beta.normalize(); + + let y = FieldElement::::conditional_select( + &beta.neg(), + &beta, + beta.is_odd().ct_eq(&y_is_odd), + ); + + // X is normalized by virtue of being created via `FromBytes`. 
+ AffinePoint::from_field_elements_unchecked(x, y.normalize()) + }) + }) + } +} + +impl DecompactPoint for AffinePoint { + fn decompact(x_bytes: &FieldBytes) -> CtOption { + Self::decompress(x_bytes, Choice::from(0)) + } +} + +impl AffineCoordinates for AffinePoint { + type FieldRepr = FieldBytes; + + fn x(&self) -> FieldBytes { + let (x, _) = self.field_elements(); + + x.to_bytes() + } + + fn y_is_odd(&self) -> Choice { + let (_, y) = self.field_elements(); + + // As field elements are created via [`Field::from_bytes`], they are already normalized. + y.is_odd() + } +} + +impl ConditionallySelectable for AffinePoint { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + // Conditional select is a constant time if-else operation. + // + // In the Pico vm, there are no attempts made to prevent side channel attacks. + if choice.into() { + *b + } else { + *a + } + } +} + +impl ConstantTimeEq for AffinePoint { + fn ct_eq(&self, other: &Self) -> Choice { + let (x1, y1) = self.field_elements(); + let (x1, y1) = (x1, y1); + + let (x2, y2) = other.field_elements(); + let (x2, y2) = (x2, y2); + + // These are already normalized by virtue of being created via `FromBytes`. + x1.ct_eq(&x2) & y1.ct_eq(&y2) + } +} + +impl PartialEq for AffinePoint { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for AffinePoint {} + +impl Default for AffinePoint { + fn default() -> Self { + AffinePoint::identity() + } +} + +impl DefaultIsZeroes for AffinePoint {} + +impl GroupEncoding for AffinePoint { + type Repr = CompressedPoint; + + fn from_bytes(bytes: &Self::Repr) -> CtOption { + EncodedPoint::::from_bytes(bytes) + .map(|point| CtOption::new(point, Choice::from(1))) + .unwrap_or_else(|_| { + // SEC1 identity encoding is technically 1-byte 0x00, but the + // `GroupEncoding` API requires a fixed-width `Repr`. 
+ let is_identity = bytes.ct_eq(&Self::Repr::default()); + CtOption::new(EncodedPoint::::identity(), is_identity) + }) + .and_then(|point| Self::from_encoded_point(&point)) + } + + fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption { + // There is no unchecked conversion for compressed points. + Self::from_bytes(bytes) + } + + fn to_bytes(&self) -> Self::Repr { + let encoded = self.to_encoded_point(true); + let mut result = CompressedPoint::::default(); + result[..encoded.len()].copy_from_slice(encoded.as_bytes()); + result + } +} diff --git a/sdk/patch-libs/src/ecdsa/projective.rs b/sdk/patch-libs/src/ecdsa/projective.rs new file mode 100644 index 00000000..03a1bfd1 --- /dev/null +++ b/sdk/patch-libs/src/ecdsa/projective.rs @@ -0,0 +1,441 @@ +//! Implementation of the Pico accelerated projective point. The projective point wraps the affine +//! point. +//! +//! This type is mainly used in the `ecdsa-core` algorithms. +//! +//! Note: When performing curve operations, accelerated crates for Pico use affine arithmetic instead +//! of projective arithmetic for performance. + +use super::{AffinePoint, ECDSACurve, PicoAffinePointTrait}; + +use elliptic_curve::{ + group::{cofactor::CofactorGroup, prime::PrimeGroup}, + ops::MulByGenerator, + sec1::{CompressedPoint, ModulusSize}, + CurveArithmetic, FieldBytes, +}; + +use elliptic_curve::{ + ff::{Field, PrimeField}, + group::{Curve, Group, GroupEncoding}, + ops::LinearCombination, + rand_core::RngCore, + subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}, + zeroize::DefaultIsZeroes, +}; + +use std::{ + iter::Sum, + ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; + +use std::borrow::Borrow; + +/// The Pico accelerated projective point. +#[derive(Clone, Copy, Debug)] +pub struct ProjectivePoint { + /// The inner affine point. + /// + /// Pico uses affine arithmetic for all operations. 
+ pub inner: AffinePoint, +} + +impl ProjectivePoint { + pub fn identity() -> Self { + ProjectivePoint { + inner: AffinePoint::::identity(), + } + } + + /// Convert the projective point to an affine point. + pub fn to_affine(self) -> AffinePoint { + self.inner + } + + fn to_zkvm_point(self) -> C::PicoAffinePoint { + self.inner.inner + } + + fn as_zkvm_point(&self) -> &C::PicoAffinePoint { + &self.inner.inner + } + + fn as_mut_zkvm_point(&mut self) -> &mut C::PicoAffinePoint { + &mut self.inner.inner + } + + /// Check if the point is the identity point. + pub fn is_identity(&self) -> Choice { + self.inner.is_identity() + } + + fn from_zkvm_point(p: C::PicoAffinePoint) -> Self { + Self { + inner: AffinePoint { inner: p }, + } + } + + pub fn double(&self) -> Self { + ::double(self) + } +} + +impl From> for ProjectivePoint { + fn from(p: AffinePoint) -> Self { + ProjectivePoint { inner: p } + } +} + +impl From<&AffinePoint> for ProjectivePoint { + fn from(p: &AffinePoint) -> Self { + ProjectivePoint { inner: *p } + } +} + +impl From> for AffinePoint { + fn from(p: ProjectivePoint) -> Self { + p.inner + } +} + +impl From<&ProjectivePoint> for AffinePoint { + fn from(p: &ProjectivePoint) -> Self { + p.inner + } +} + +impl Group for ProjectivePoint { + type Scalar = ::Scalar; + + fn identity() -> Self { + Self::identity() + } + + fn random(rng: impl RngCore) -> Self { + ProjectivePoint::::generator() * Self::Scalar::random(rng) + } + + fn double(&self) -> Self { + *self + self + } + + fn generator() -> Self { + Self { + inner: AffinePoint::::generator(), + } + } + + fn is_identity(&self) -> Choice { + self.inner.is_identity() + } +} + +impl Curve for ProjectivePoint { + type AffineRepr = AffinePoint; + + fn to_affine(&self) -> Self::AffineRepr { + self.inner + } +} + +impl MulByGenerator for ProjectivePoint {} + +impl LinearCombination for ProjectivePoint { + fn lincomb(x: &Self, k: &Self::Scalar, y: &Self, l: &Self::Scalar) -> Self { + let x = x.to_zkvm_point(); + let y 
= y.to_zkvm_point(); + + let a_bits_le = be_bytes_to_le_bits(k.to_repr().as_ref()); + let b_bits_le = be_bytes_to_le_bits(l.to_repr().as_ref()); + + let pico_point = + C::PicoAffinePoint::multi_scalar_multiplication(&a_bits_le, x, &b_bits_le, y); + + Self::from_zkvm_point(pico_point) + } +} + +// Implementation of scalar multiplication for the projective point. + +impl> Mul for ProjectivePoint { + type Output = ProjectivePoint; + + fn mul(mut self, rhs: T) -> Self::Output { + let pico_point = self.as_mut_zkvm_point(); + pico_point.mul_assign(&be_bytes_to_le_words(rhs.borrow().to_repr())); + + self + } +} + +impl> MulAssign for ProjectivePoint { + fn mul_assign(&mut self, rhs: T) { + self.as_mut_zkvm_point() + .mul_assign(&be_bytes_to_le_words(rhs.borrow().to_repr())); + } +} + +// Implementation of projective arithmetic. + +impl Neg for ProjectivePoint { + type Output = ProjectivePoint; + + fn neg(self) -> Self::Output { + if self.is_identity().into() { + return self; + } + + let point = self.to_affine(); + let (x, y) = point.field_elements(); + + AffinePoint::::from_field_elements_unchecked(x, y.neg()).into() + } +} + +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(mut self, rhs: ProjectivePoint) -> Self::Output { + self.as_mut_zkvm_point().add_assign(rhs.as_zkvm_point()); + + self + } +} + +impl Add<&ProjectivePoint> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(mut self, rhs: &ProjectivePoint) -> Self::Output { + self.as_mut_zkvm_point().add_assign(rhs.as_zkvm_point()); + + self + } +} + +impl Sub> for ProjectivePoint { + type Output = ProjectivePoint; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: ProjectivePoint) -> Self::Output { + self + rhs.neg() + } +} + +impl Sub<&ProjectivePoint> for ProjectivePoint { + type Output = ProjectivePoint; + + #[allow(clippy::suspicious_arithmetic_impl)] + fn sub(self, rhs: &ProjectivePoint) -> Self::Output { + self + (*rhs).neg() + } +} + +impl Sum> for 
ProjectivePoint { + fn sum>(iter: I) -> Self { + iter.fold(Self::identity(), |a, b| a + b) + } +} + +impl<'a, C: ECDSACurve> Sum<&'a ProjectivePoint> for ProjectivePoint { + fn sum>>(iter: I) -> Self { + iter.cloned().sum() + } +} + +impl AddAssign> for ProjectivePoint { + fn add_assign(&mut self, rhs: ProjectivePoint) { + self.as_mut_zkvm_point().add_assign(rhs.as_zkvm_point()); + } +} + +impl AddAssign<&ProjectivePoint> for ProjectivePoint { + fn add_assign(&mut self, rhs: &ProjectivePoint) { + self.as_mut_zkvm_point().add_assign(rhs.as_zkvm_point()); + } +} + +impl SubAssign> for ProjectivePoint { + fn sub_assign(&mut self, rhs: ProjectivePoint) { + self.as_mut_zkvm_point() + .add_assign(rhs.neg().as_zkvm_point()); + } +} + +impl SubAssign<&ProjectivePoint> for ProjectivePoint { + fn sub_assign(&mut self, rhs: &ProjectivePoint) { + self.as_mut_zkvm_point() + .add_assign(rhs.neg().as_zkvm_point()); + } +} + +impl Default for ProjectivePoint { + fn default() -> Self { + Self::identity() + } +} + +// Implementation of mixed arithmetic. 
+ +impl Add> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: AffinePoint) -> Self::Output { + self + ProjectivePoint { inner: rhs } + } +} + +impl Add<&AffinePoint> for ProjectivePoint { + type Output = ProjectivePoint; + + fn add(self, rhs: &AffinePoint) -> Self::Output { + self + ProjectivePoint { inner: *rhs } + } +} + +impl AddAssign> for ProjectivePoint { + fn add_assign(&mut self, rhs: AffinePoint) { + self.as_mut_zkvm_point().add_assign(&rhs.inner); + } +} + +impl AddAssign<&AffinePoint> for ProjectivePoint { + fn add_assign(&mut self, rhs: &AffinePoint) { + self.as_mut_zkvm_point().add_assign(&rhs.inner); + } +} + +impl Sub> for ProjectivePoint { + type Output = ProjectivePoint; + + fn sub(self, rhs: AffinePoint) -> Self::Output { + self - ProjectivePoint { inner: rhs } + } +} + +impl Sub<&AffinePoint> for ProjectivePoint { + type Output = ProjectivePoint; + + fn sub(self, rhs: &AffinePoint) -> Self::Output { + self - ProjectivePoint { inner: *rhs } + } +} + +impl SubAssign> for ProjectivePoint { + fn sub_assign(&mut self, rhs: AffinePoint) { + let projective = ProjectivePoint { inner: rhs }.neg(); + + self.as_mut_zkvm_point() + .add_assign(projective.as_zkvm_point()); + } +} + +impl SubAssign<&AffinePoint> for ProjectivePoint { + fn sub_assign(&mut self, rhs: &AffinePoint) { + let projective = ProjectivePoint { inner: *rhs }.neg(); + + self.as_mut_zkvm_point() + .add_assign(projective.as_zkvm_point()); + } +} + +impl DefaultIsZeroes for ProjectivePoint {} + +impl ConditionallySelectable for ProjectivePoint { + fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self { + Self { + inner: AffinePoint::conditional_select(&a.inner, &b.inner, choice), + } + } +} + +impl ConstantTimeEq for ProjectivePoint { + fn ct_eq(&self, other: &Self) -> Choice { + self.inner.ct_eq(&other.inner) + } +} + +impl PartialEq for ProjectivePoint { + fn eq(&self, other: &Self) -> bool { + self.ct_eq(other).into() + } +} + +impl Eq for 
ProjectivePoint {} + +impl GroupEncoding for ProjectivePoint +where + FieldBytes: Copy, + C::FieldBytesSize: ModulusSize, + CompressedPoint: Copy, +{ + type Repr = CompressedPoint; + + fn from_bytes(bytes: &Self::Repr) -> CtOption { + as GroupEncoding>::from_bytes(bytes).map(Into::into) + } + + fn from_bytes_unchecked(bytes: &Self::Repr) -> CtOption { + // No unchecked conversion possible for compressed points. + Self::from_bytes(bytes) + } + + fn to_bytes(&self) -> Self::Repr { + self.inner.to_bytes() + } +} + +impl PrimeGroup for ProjectivePoint +where + FieldBytes: Copy, + C::FieldBytesSize: ModulusSize, + CompressedPoint: Copy, +{ +} + +/// The scalar field has prime order, so the cofactor is 1. +impl CofactorGroup for ProjectivePoint +where + FieldBytes: Copy, + C::FieldBytesSize: ModulusSize, + CompressedPoint: Copy, +{ + type Subgroup = Self; + + fn clear_cofactor(&self) -> Self { + *self + } + + fn into_subgroup(self) -> CtOption { + CtOption::new(self, Choice::from(1)) + } + + fn is_torsion_free(&self) -> Choice { + Choice::from(1) + } +} + +#[inline] +fn be_bytes_to_le_words>(mut bytes: T) -> [u64; 4] { + let bytes = bytes.as_mut(); + bytes.reverse(); + + let mut iter = bytes + .chunks(8) + .map(|b| u64::from_le_bytes(b.try_into().unwrap())); + core::array::from_fn(|_| iter.next().unwrap()) +} + +/// Convert big-endian bytes with the most significant bit first to little-endian bytes with the +/// least significant bit first. Panics: If the bytes have len > 32. +#[inline] +fn be_bytes_to_le_bits(be_bytes: &[u8]) -> [bool; 256] { + let mut bits = [false; 256]; + // Reverse the byte order to little-endian. + for (i, &byte) in be_bytes.iter().rev().enumerate() { + for j in 0..8 { + // Flip the bit order so the least significant bit is now the first bit of the chunk. 
+ bits[i * 8 + j] = ((byte >> j) & 1) == 1; + } + } + bits +} diff --git a/sdk/patch-libs/src/ed25519.rs b/sdk/patch-libs/src/ed25519.rs index 3eea834e..ac45982a 100644 --- a/sdk/patch-libs/src/ed25519.rs +++ b/sdk/patch-libs/src/ed25519.rs @@ -1,30 +1,42 @@ use crate::{syscall_ed_add, utils::AffinePoint}; /// The number of limbs in [Ed25519AffinePoint]. -pub const N: usize = 16; +pub const N: usize = 8; /// An affine point on the Ed25519 curve. #[derive(Copy, Clone)] -#[repr(align(4))] -pub struct Ed25519AffinePoint(pub [u32; N]); +#[repr(align(8))] +pub struct Ed25519AffinePoint(pub [u64; N]); impl AffinePoint for Ed25519AffinePoint { /// The generator/base point for the Ed25519 curve. Reference: https://datatracker.ietf.org/doc/html/rfc7748#section-4.1 - const GENERATOR: [u32; N] = [ - 216936062, 3086116296, 2351951131, 1681893421, 3444223839, 2756123356, 3800373269, - 3284567716, 2518301344, 752319464, 3983256831, 1952656717, 3669724772, 3793645816, - 3665724614, 2969860233, + const GENERATOR: [u64; N] = [ + 13254768563189591678, + 7223677240904510747, + 11837459681205989215, + 14107110925517789205, + 3231187496542550688, + 8386596743812984063, + 16293584715996958308, + 12755452578091664582, ]; - fn new(limbs: [u32; N]) -> Self { + #[allow(deprecated)] + const GENERATOR_T: Self = Self(Self::GENERATOR); + + fn new(limbs: [u64; N]) -> Self { Self(limbs) } - fn limbs_ref(&self) -> &[u32; N] { + fn identity() -> Self { + Self::identity() + } + + fn limbs_ref(&self) -> &[u64; N] { &self.0 } - fn limbs_mut(&mut self) -> &mut [u32; N] { + fn limbs_mut(&mut self) -> &mut [u64; N] { &mut self.0 } @@ -36,6 +48,10 @@ impl AffinePoint for Ed25519AffinePoint { } } + fn is_identity(&self) -> bool { + self.0 == Self::IDENTITY + } + /// In Edwards curves, doubling is the same as adding a point to itself. 
fn double(&mut self) { let a = self.limbs_mut(); @@ -46,7 +62,7 @@ impl AffinePoint for Ed25519AffinePoint { } impl Ed25519AffinePoint { - const IDENTITY: [u32; N] = [0; N]; + const IDENTITY: [u64; N] = [0, 0, 0, 0, 1, 0, 0, 0]; pub fn identity() -> Self { Self(Self::IDENTITY) diff --git a/sdk/patch-libs/src/io.rs b/sdk/patch-libs/src/io.rs index acfe455a..3d83d604 100644 --- a/sdk/patch-libs/src/io.rs +++ b/sdk/patch-libs/src/io.rs @@ -24,6 +24,18 @@ pub const FD_EDDECOMPRESS: u32 = 8; /// The file descriptor for brevis coprocessor outputs. pub const FD_COPROCESSOR_OUTPUTS: u32 = 9; +/// The file descriptor through which to access `hook_fp_sqrt`. +pub const FD_FP_SQRT: u32 = 10; + +/// The file descriptor through which to access `hook_fp_inverse`. +pub const FD_FP_INV: u32 = 11; + +/// The file descriptor through which to access `hook_bls12_381_sqrt`. +pub const FD_BLS12_381_SQRT: u32 = 12; + +/// The file descriptor through which to access `hook_bls12_381_inverse`. +pub const FD_BLS12_381_INVERSE: u32 = 13; + /// A writer that writes to a file descriptor inside the zkVM. pub struct SyscallWriter { pub fd: u32, @@ -44,36 +56,67 @@ impl Write for SyscallWriter { } } +#[repr(C)] +pub struct ReadVecResult { + pub ptr: *mut u8, + pub len: usize, + pub capacity: usize, +} + /// Read a buffer from the input stream. /// -/// ### Examples -/// ```ignore -/// let data: Vec = pico_sdk::io::read_vec(); -/// ``` -pub fn read_vec() -> Vec { - // Round up to the nearest multiple of 4 so that the memory allocated is in whole words +/// The buffer is read into uninitialized memory. +fn read_vec_raw() -> ReadVecResult { + // Get the length of the input buffer. 
let len = unsafe { syscall_hint_len() }; - let capacity = len.div_ceil(4) * 4; - // Allocate a buffer of the required length that is 4 byte aligned - let layout = Layout::from_size_align(capacity, 4).expect("vec is too large"); - let ptr = unsafe { std::alloc::alloc(layout) }; + // If the length is usize::MAX, then the input stream is exhausted. + if len == usize::MAX { + return ReadVecResult { + ptr: std::ptr::null_mut(), + len: 0, + capacity: 0, + }; + } + + // Round up to multiple of 8 for whole-word alignment. + let capacity = len.div_ceil(8) * 8; - // SAFETY: - // 1. `ptr` was allocated using alloc - // 2. We assuume that the VM global allocator doesn't dealloc - // 3/6. Size is correct from above - // 4/5. Length is 0 - // 7. Layout::from_size_align already checks this - let mut vec = unsafe { Vec::from_raw_parts(ptr, 0, capacity) }; + // Allocate a buffer of the required length that is 8 byte aligned. + let layout = Layout::from_size_align(capacity, 8).expect("vec is too large"); - // Read the vec into uninitialized memory. The syscall assumes the memory is uninitialized, - // which should be true because the allocator does not dealloc, so a new alloc should be fresh. + // SAFETY: The layout was made through the checked constructor. + let ptr = unsafe { std::alloc::alloc(layout) }; + + // Read the vec into uninitialized memory. The syscall assumes the memory is + // uninitialized, which is true because the bump allocator does not dealloc, so a new + // alloc is always fresh. unsafe { syscall_hint_read(ptr, len); - vec.set_len(len); } - vec + + // Return the result. + ReadVecResult { ptr, len, capacity } +} + +/// Read a buffer from the input stream.
+/// +/// ### Examples +/// ```ignore +/// let data: Vec = pico_sdk::io::read_vec(); +/// ``` +pub fn read_vec() -> Vec { + let ReadVecResult { ptr, len, capacity } = unsafe { read_vec_raw() }; + + if ptr.is_null() { + panic!( + "Tried to read from the input stream, but it was empty @ {} \n + Was the correct data written into Stdin?", + std::panic::Location::caller(), + ) + } + + unsafe { Vec::from_raw_parts(ptr, len, capacity) } } /// Read a deserializable object from the input stream. @@ -95,6 +138,45 @@ pub fn read() -> T { bincode::deserialize(&vec).expect("deserialization failed") } +/// Commit a serializable object to the public values stream. +/// +/// ### Examples +/// ```ignore +/// use serde::{Deserialize, Serialize}; +/// +/// #[derive(Serialize, Deserialize)] +/// struct MyStruct { +/// a: u32, +/// b: u32, +/// } +/// +/// let data = MyStruct { +/// a: 1, +/// b: 2, +/// }; +/// pico_sdk::io::commit(&data); +/// ``` +pub fn commit(value: &T) { + let writer = SyscallWriter { + fd: FD_PUBLIC_VALUES, + }; + bincode::serialize_into(writer, value).expect("serialization failed"); +} + +/// Commit bytes to the public values stream. +/// +/// ### Examples +/// ```ignore +/// let data = vec![1, 2, 3, 4]; +/// pico_sdk::io::commit_slice(&data); +/// ``` +pub fn commit_slice(buf: &[u8]) { + let mut my_writer = SyscallWriter { + fd: FD_PUBLIC_VALUES, + }; + my_writer.write_all(buf).unwrap(); +} + /// Hint a serializable object to the hint stream. /// /// ### Examples diff --git a/sdk/patch-libs/src/lib.rs b/sdk/patch-libs/src/lib.rs index eff8c678..0ebb26ee 100644 --- a/sdk/patch-libs/src/lib.rs +++ b/sdk/patch-libs/src/lib.rs @@ -1,6 +1,7 @@ //! Wrapper syscall API for the Pico patches. pub mod bls12381; pub mod bn254; +pub mod ecdsa; pub mod ed25519; pub mod io; pub mod secp256k1; @@ -27,52 +28,52 @@ extern "C" { pub fn syscall_read(fd: u32, read_buf: *mut u8, nbytes: usize); /// Executes the SHA-256 extend operation on the given word array. 
- pub fn syscall_sha256_extend(w: *mut [u32; 64]); + pub fn syscall_sha256_extend(w: *mut [u64; 64]); /// Executes the SHA-256 compress operation on the given word array and a given state. - pub fn syscall_sha256_compress(w: *mut [u32; 64], state: *mut [u32; 8]); + pub fn syscall_sha256_compress(w: *mut [u64; 64], state: *mut [u64; 8]); /// Executes an Ed25519 curve addition on the given points. - pub fn syscall_ed_add(p: *mut [u32; 16], q: *const [u32; 16]); + pub fn syscall_ed_add(p: *mut [u64; 8], q: *const [u64; 8]); /// Executes an Ed25519 curve decompression on the given point. - pub fn syscall_ed_decompress(point: &mut [u8; 64]); + pub fn syscall_ed_decompress(point: &mut [u64; 8]); - /// Executes an Secp256k1 curve addition on the given points. - pub fn syscall_secp256k1_add(p: *mut [u32; 16], q: *const [u32; 16]); + /// Executes an Secp256k1 curve addition on the given points. + pub fn syscall_secp256k1_add(p: *mut [u64; 8], q: *const [u64; 8]); /// Executes an Secp256k1 curve doubling on the given point. - pub fn syscall_secp256k1_double(p: *mut [u32; 16]); + pub fn syscall_secp256k1_double(p: *mut [u64; 8]); /// Executes an Secp256k1 curve decompression on the given point. - pub fn syscall_secp256k1_decompress(point: &mut [u8; 64], is_odd: bool); + pub fn syscall_secp256k1_decompress(point: &mut [u64; 8], is_odd: bool); /// Executes an Secp256r1 curve addition on the given points. - pub fn syscall_secp256r1_add(p: *mut [u32; 16], q: *const [u32; 16]); + pub fn syscall_secp256r1_add(p: *mut [u64; 8], q: *const [u64; 8]); /// Executes an Secp256r1 curve doubling on the given point. - pub fn syscall_secp256r1_double(p: *mut [u32; 16]); + pub fn syscall_secp256r1_double(p: *mut [u64; 8]); /// Executes an Secp256r1 curve decompression on the given point. - pub fn syscall_secp256r1_decompress(point: &mut [u8; 64], is_odd: bool); + pub fn syscall_secp256r1_decompress(point: &mut [u64; 8], is_odd: bool); /// Executes a Bn254 curve addition on the given points.
- pub fn syscall_bn254_add(p: *mut [u32; 16], q: *const [u32; 16]); + pub fn syscall_bn254_add(p: *mut [u64; 8], q: *const [u64; 8]); /// Executes a Bn254 curve doubling on the given point. - pub fn syscall_bn254_double(p: *mut [u32; 16]); + pub fn syscall_bn254_double(p: *mut [u64; 8]); /// Executes a BLS12-381 curve addition on the given points. - pub fn syscall_bls12381_add(p: *mut [u32; 24], q: *const [u32; 24]); + pub fn syscall_bls12381_add(p: *mut [u64; 12], q: *const [u64; 12]); /// Executes a BLS12-381 curve doubling on the given point. - pub fn syscall_bls12381_double(p: *mut [u32; 24]); + pub fn syscall_bls12381_double(p: *mut [u64; 12]); /// Executes the Keccak-256 permutation on the given state. pub fn syscall_keccak_permute(state: *mut [u64; 25]); /// Executes an uint256 multiplication on the given inputs. - pub fn syscall_uint256_mulmod(x: *mut [u32; 8], y: *const [u32; 8]); + pub fn syscall_uint256_mulmod(x: *mut [u64; 4], y: *const [u64; 4]); /// Enters unconstrained mode. pub fn syscall_enter_unconstrained() -> bool; @@ -93,53 +94,54 @@ extern "C" { pub fn sys_alloc_aligned(bytes: usize, align: usize) -> *mut u8; /// Decompresses a BLS12-381 point. - pub fn syscall_bls12381_decompress(point: &mut [u8; 96], is_odd: bool); + pub fn syscall_bls12381_decompress(point: &mut [u64; 12], is_odd: bool); /// Computes a big integer operation with a modulus. pub fn sys_bigint( - result: *mut [u32; 8], - op: u32, - x: *const [u32; 8], - y: *const [u32; 8], - modulus: *const [u32; 8], + result: *mut [u64; 4], + op: u64, + x: *const [u64; 4], + y: *const [u64; 4], + modulus: *const [u64; 4], ); /// Executes a BLS12-381 field addition on the given inputs. - pub fn syscall_bls12381_fp_addmod(p: *mut u32, q: *const u32); + pub fn syscall_bls12381_fp_addmod(p: *mut u64, q: *const u64); /// Executes a BLS12-381 field subtraction on the given inputs. 
- pub fn syscall_bls12381_fp_submod(p: *mut u32, q: *const u32); + pub fn syscall_bls12381_fp_submod(p: *mut u64, q: *const u64); /// Executes a BLS12-381 field multiplication on the given inputs. - pub fn syscall_bls12381_fp_mulmod(p: *mut u32, q: *const u32); + pub fn syscall_bls12381_fp_mulmod(p: *mut u64, q: *const u64); /// Executes a BLS12-381 Fp2 addition on the given inputs. - pub fn syscall_bls12381_fp2_addmod(p: *mut u32, q: *const u32); + pub fn syscall_bls12381_fp2_addmod(p: *mut u64, q: *const u64); /// Executes a BLS12-381 Fp2 subtraction on the given inputs. - pub fn syscall_bls12381_fp2_submod(p: *mut u32, q: *const u32); + pub fn syscall_bls12381_fp2_submod(p: *mut u64, q: *const u64); /// Executes a BLS12-381 Fp2 multiplication on the given inputs. - pub fn syscall_bls12381_fp2_mulmod(p: *mut u32, q: *const u32); + pub fn syscall_bls12381_fp2_mulmod(p: *mut u64, q: *const u64); /// Executes a BN254 field addition on the given inputs. - pub fn syscall_bn254_fp_addmod(p: *mut u32, q: *const u32); + pub fn syscall_bn254_fp_addmod(p: *mut u64, q: *const u64); /// Executes a BN254 field subtraction on the given inputs. - pub fn syscall_bn254_fp_submod(p: *mut u32, q: *const u32); + pub fn syscall_bn254_fp_submod(p: *mut u64, q: *const u64); /// Executes a BN254 field multiplication on the given inputs. - pub fn syscall_bn254_fp_mulmod(p: *mut u32, q: *const u32); + pub fn syscall_bn254_fp_mulmod(p: *mut u64, q: *const u64); /// Executes a BN254 Fp2 addition on the given inputs. - pub fn syscall_bn254_fp2_addmod(p: *mut u32, q: *const u32); + pub fn syscall_bn254_fp2_addmod(p: *mut u64, q: *const u64); /// Executes a BN254 Fp2 subtraction on the given inputs. - pub fn syscall_bn254_fp2_submod(p: *mut u32, q: *const u32); + pub fn syscall_bn254_fp2_submod(p: *mut u64, q: *const u64); /// Executes a BN254 Fp2 multiplication on the given inputs. 
- pub fn syscall_bn254_fp2_mulmod(p: *mut u32, q: *const u32); + pub fn syscall_bn254_fp2_mulmod(p: *mut u64, q: *const u64); + /* /// Executes a Secp256k1 field addition on the given inputs. pub fn syscall_secp256k1_fp_addmod(p: *mut u32, q: *const u32); @@ -148,8 +150,8 @@ extern "C" { /// Executes a Secp256k1 field multiplication on the given inputs. pub fn syscall_secp256k1_fp_mulmod(p: *mut u32, q: *const u32); + */ /// Executes an poseidon2 permute on the given inputs. pub fn syscall_poseidon2_permute(x: *const [u32; 16], y: *mut [u32; 16]); - } diff --git a/sdk/patch-libs/src/secp256k1.rs b/sdk/patch-libs/src/secp256k1.rs index 0c5f65ed..f7c0217f 100644 --- a/sdk/patch-libs/src/secp256k1.rs +++ b/sdk/patch-libs/src/secp256k1.rs @@ -4,11 +4,11 @@ use crate::{ }; /// The number of limbs in [Secp256k1Point]. -pub const N: usize = 16; +pub const N: usize = 8; /// An affine point on the Secp256k1 curve. -#[derive(Copy, Clone)] -#[repr(align(4))] +#[derive(Copy, Clone, Debug)] +#[repr(align(8))] pub struct Secp256k1Point(pub WeierstrassPoint); impl WeierstrassAffinePoint for Secp256k1Point { @@ -23,34 +23,46 @@ impl WeierstrassAffinePoint for Secp256k1Point { impl AffinePoint for Secp256k1Point { /// The values are taken from https://en.bitcoin.it/wiki/Secp256k1. 
- const GENERATOR: [u32; N] = [ - 385357720, 1509065051, 768485593, 43777243, 3464956679, 1436574357, 4191992748, 2042521214, - 4212184248, 2621952143, 2793755673, 4246189128, 235997352, 1571093500, 648266853, - 1211816567, + const GENERATOR: [u64; N] = [ + 6481385041966929816, + 188021827762530521, + 6170039885052185351, + 8772561819708210092, + 11261198710074299576, + 18237243440184513561, + 6747795201694173352, + 5204712524664259685, ]; - fn new(limbs: [u32; N]) -> Self { + #[allow(deprecated)] + const GENERATOR_T: Self = Self(WeierstrassPoint::Affine(Self::GENERATOR)); + + fn new(limbs: [u64; N]) -> Self { Self(WeierstrassPoint::Affine(limbs)) } - fn limbs_ref(&self) -> &[u32; N] { + fn identity() -> Self { + Self::infinity() + } + + fn is_identity(&self) -> bool { + self.is_infinity() + } + + fn limbs_ref(&self) -> &[u64; N] { match &self.0 { WeierstrassPoint::Infinity => panic!("Infinity point has no limbs"), WeierstrassPoint::Affine(limbs) => limbs, } } - fn limbs_mut(&mut self) -> &mut [u32; N] { + fn limbs_mut(&mut self) -> &mut [u64; N] { match &mut self.0 { WeierstrassPoint::Infinity => panic!("Infinity point has no limbs"), WeierstrassPoint::Affine(limbs) => limbs, } } - fn complete_add_assign(&mut self, other: &Self) { - self.weierstrass_add_assign(other); - } - fn add_assign(&mut self, other: &Self) { let a = self.limbs_mut(); let b = other.limbs_ref(); @@ -59,6 +71,10 @@ impl AffinePoint for Secp256k1Point { } } + fn complete_add_assign(&mut self, other: &Self) { + self.weierstrass_add_assign(other); + } + fn double(&mut self) { match &mut self.0 { WeierstrassPoint::Infinity => (), diff --git a/sdk/patch-libs/src/secp256r1.rs b/sdk/patch-libs/src/secp256r1.rs index aa2faf1b..9d316f6f 100644 --- a/sdk/patch-libs/src/secp256r1.rs +++ b/sdk/patch-libs/src/secp256r1.rs @@ -4,11 +4,11 @@ use crate::{ }; /// The number of limbs in [Secp256r1Point]. -pub const N: usize = 16; +pub const N: usize = 8; -/// An affine point on the Secp256r1 curve. 
-#[derive(Copy, Clone)] -#[repr(align(4))] +/// An affine point on the Secp256r1 curve. +#[derive(Copy, Clone, Debug)] +#[repr(align(8))] pub struct Secp256r1Point(pub WeierstrassPoint); impl WeierstrassAffinePoint for Secp256r1Point { @@ -22,34 +22,46 @@ impl WeierstrassAffinePoint for Secp256r1Point { } impl AffinePoint for Secp256r1Point { - const GENERATOR: [u32; N] = [ - 3633889942, 4104206661, 770388896, 1996717441, 1671708914, 4173129445, 3777774151, - 1796723186, 935285237, 3417718888, 1798397646, 734933847, 2081398294, 2397563722, - 4263149467, 1340293858, + const GENERATOR: [u64; N] = [ + 17627433388654248598, + 8575836109218198432, + 17923454489921339634, + 7716867327612699207, + 14678990851816772085, + 3156516839386865358, + 10297457778147434006, + 5756518291402817435, ]; - fn new(limbs: [u32; N]) -> Self { + #[allow(deprecated)] + const GENERATOR_T: Self = Self(WeierstrassPoint::Affine(Self::GENERATOR)); + + fn new(limbs: [u64; N]) -> Self { Self(WeierstrassPoint::Affine(limbs)) } - fn limbs_ref(&self) -> &[u32; N] { + fn identity() -> Self { + Self::infinity() + } + + fn is_identity(&self) -> bool { + self.is_infinity() + } + + fn limbs_ref(&self) -> &[u64; N] { match &self.0 { WeierstrassPoint::Infinity => panic!("Infinity point has no limbs"), WeierstrassPoint::Affine(limbs) => limbs, } } - fn limbs_mut(&mut self) -> &mut [u32; N] { + fn limbs_mut(&mut self) -> &mut [u64; N] { match &mut self.0 { WeierstrassPoint::Infinity => panic!("Infinity point has no limbs"), WeierstrassPoint::Affine(limbs) => limbs, } } - fn complete_add_assign(&mut self, other: &Self) { - self.weierstrass_add_assign(other); - } - fn add_assign(&mut self, other: &Self) { let a = self.limbs_mut(); let b = other.limbs_ref(); @@ -58,6 +70,10 @@ impl AffinePoint for Secp256r1Point { } } + fn complete_add_assign(&mut self, other: &Self) { + self.weierstrass_add_assign(other); + } + fn double(&mut self) { match &mut self.0 { WeierstrassPoint::Infinity => (), diff --git
a/sdk/patch-libs/src/utils.rs b/sdk/patch-libs/src/utils.rs index 94832cc8..8fb5e930 100644 --- a/sdk/patch-libs/src/utils.rs +++ b/sdk/patch-libs/src/utils.rs @@ -1,24 +1,33 @@ pub trait AffinePoint: Clone + Sized { /// The generator. - const GENERATOR: [u32; N]; + #[deprecated = "This const will have the `Self` type in the next major version."] + const GENERATOR: [u64; N]; + + const GENERATOR_T: Self; /// Creates a new [`AffinePoint`] from the given limbs. - fn new(limbs: [u32; N]) -> Self; + fn new(limbs: [u64; N]) -> Self; + + /// Creates a new [`AffinePoint`] that corresponds to the identity point. + fn identity() -> Self; /// Returns a reference to the limbs. - fn limbs_ref(&self) -> &[u32; N]; + fn limbs_ref(&self) -> &[u64; N]; - /// Returns a mutable reference to the limbs. If the point is the infinity point, this will panic. - fn limbs_mut(&mut self) -> &mut [u32; N]; + /// Returns a mutable reference to the limbs. If the point is the infinity point, this will + /// panic. + fn limbs_mut(&mut self) -> &mut [u64; N]; + + fn is_identity(&self) -> bool; /// Creates a new [`AffinePoint`] from the given x and y coordinates. /// /// The bytes are the concatenated little endian representations of the coordinates. fn from(x: &[u8], y: &[u8]) -> Self { - debug_assert!(x.len() == N * 2); - debug_assert!(y.len() == N * 2); + debug_assert!(x.len() == N * 4); + debug_assert!(y.len() == N * 4); - let mut limbs = [0u32; N]; + let mut limbs = [0u64; N]; let x = bytes_to_words_le(x); let y = bytes_to_words_le(y); @@ -40,7 +49,7 @@ pub trait AffinePoint: Clone + Sized { /// Creates a new [`AffinePoint`] from the given bytes in big endian. fn to_le_bytes(&self) -> Vec { let le_bytes = words_to_bytes_le(self.limbs_ref()); - debug_assert!(le_bytes.len() == N * 4); + debug_assert!(le_bytes.len() == N * 8); le_bytes } @@ -48,7 +57,8 @@ pub trait AffinePoint: Clone + Sized { fn add_assign(&mut self, other: &Self); /// Adds the given [`AffinePoint`] to `self`. 
Can be optionally overridden to use a different - /// implementation of addition in multi-scalar multiplication, which is used in secp256k1 recovery. + /// implementation of addition in multi-scalar multiplication, which is used in secp256k1 + /// recovery. fn complete_add_assign(&mut self, other: &Self) { self.add_assign(other); } @@ -57,32 +67,22 @@ pub trait AffinePoint: Clone + Sized { fn double(&mut self); /// Multiplies `self` by the given scalar. - fn mul_assign(&mut self, scalar: &[u32]) -> Result<(), MulAssignError> { - debug_assert!(scalar.len() == N / 2); + fn mul_assign(&mut self, scalar: &[u64]) { + debug_assert_eq!(scalar.len(), N / 2); - let mut res: Option = None; + let mut res: Self = Self::identity(); let mut temp = self.clone(); - let scalar_is_zero = scalar.iter().all(|&words| words == 0); - if scalar_is_zero { - return Err(MulAssignError::ScalarIsZero); - } - for &words in scalar.iter() { - for i in 0..32 { + for i in 0..u64::BITS { if (words >> i) & 1 == 1 { - match res.as_mut() { - Some(res) => res.add_assign(&temp), - None => res = Some(temp.clone()), - }; + res.complete_add_assign(&temp); } - temp.double(); } } - *self = res.unwrap(); - Ok(()) + *self = res; } /// Performs multi-scalar multiplication (MSM) on slices of bit vectors and points. Note: @@ -92,28 +92,20 @@ pub trait AffinePoint: Clone + Sized { a: Self, b_bits_le: &[bool], b: Self, - ) -> Option { + ) -> Self { // The length of the bit vectors must be the same. 
debug_assert!(a_bits_le.len() == b_bits_le.len()); - let mut res: Option = None; + let mut res: Self = Self::identity(); let mut temp_a = a.clone(); let mut temp_b = b.clone(); for (a_bit, b_bit) in a_bits_le.iter().zip(b_bits_le.iter()) { if *a_bit { - match res.as_mut() { - Some(res) => res.complete_add_assign(&temp_a), - None => res = Some(temp_a.clone()), - }; + res.complete_add_assign(&temp_a); } - if *b_bit { - match res.as_mut() { - Some(res) => res.complete_add_assign(&temp_b), - None => res = Some(temp_b.clone()), - }; + res.complete_add_assign(&temp_b); } - temp_a.double(); temp_b.double(); } @@ -128,26 +120,26 @@ pub enum MulAssignError { } /// Converts a slice of words to a byte array in little endian. -pub fn words_to_bytes_le(words: &[u32]) -> Vec { +pub fn words_to_bytes_le(words: &[u64]) -> Vec { words .iter() - .flat_map(|word| word.to_le_bytes().to_vec()) + .flat_map(|word| word.to_le_bytes()) .collect::>() } /// Converts a byte array in little endian to a slice of words. -pub fn bytes_to_words_le(bytes: &[u8]) -> Vec { +pub fn bytes_to_words_le(bytes: &[u8]) -> Vec { bytes - .chunks_exact(4) - .map(|chunk| u32::from_le_bytes(chunk.try_into().unwrap())) + .chunks_exact(8) + .map(|chunk| u64::from_le_bytes(chunk.try_into().unwrap())) .collect::>() } -#[derive(Copy, Clone)] +#[derive(Copy, Clone, Debug)] /// A representation of a point on a Weierstrass curve. pub enum WeierstrassPoint { Infinity, - Affine([u32; N]), + Affine([u64; N]), } /// A trait for affine points on Weierstrass curves. 
diff --git a/sdk/sdk/src/lib.rs b/sdk/sdk/src/lib.rs index aa99360c..4205e4c8 100644 --- a/sdk/sdk/src/lib.rs +++ b/sdk/sdk/src/lib.rs @@ -77,11 +77,10 @@ mod zkvm { syscall_halt(0); } - static STACK_TOP: u32 = 0x0020_0400; - core::arch::global_asm!(include_str!("memset.s")); core::arch::global_asm!(include_str!("memcpy.s")); + static STACK_TOP: u64 = 0x7800_0000; core::arch::global_asm!( r#" .section .text._start; @@ -92,7 +91,7 @@ mod zkvm { la gp, __global_pointer$; .option pop; la sp, {0} - lw sp, 0(sp) + ld sp, 0(sp) call __start; "#, sym STACK_TOP diff --git a/sdk/sdk/src/memcpy.s b/sdk/sdk/src/memcpy.s index 1735cae5..0d06805c 100644 --- a/sdk/sdk/src/memcpy.s +++ b/sdk/sdk/src/memcpy.s @@ -1,13 +1,3 @@ -// This is musl-libc commit 37e18b7bf307fa4a8c745feebfcba54a0ba74f30: -// -// src/string/memcpy.c -// -// This was compiled into assembly with: -// -// clang-14 -target riscv32 -march=rv32im -O3 -S memcpy.c -nostdlib -fno-builtin -funroll-loops -// -// and labels manually updated to not conflict. -// // musl as a whole is licensed under the following standard MIT license: // // ---------------------------------------------------------------------- @@ -179,8 +169,8 @@ // interest of source tree size. // // In addition, permission is hereby granted for all public header files -// (include/* and arch/* /bits/* ) and crt files intended to be linked into -// applications (crt/*, ldso/dlstart.c, and arch/* /crt_arch.h) to omit +// (include/* and arch/*/bits/*) and crt files intended to be linked into +// applications (crt/*, ldso/dlstart.c, and arch/*/crt_arch.h) to omit // the copyright notice and permission notice otherwise required by the // license, and to use these files without any requirement of // attribution. These files include substantial contributions from: @@ -201,298 +191,590 @@ // negated the permissions granted in the license. 
In the spirit of // permissive licensing, and of not having licensing issues being an // obstacle to adoption, that text has been removed. - .text .attribute 4, 16 - .attribute 5, "rv32im" - .file "musl_memcpy.c" - .globl memcpy + .attribute 5, "rv64i2p1_m2p0_zmmul1p0" + .file "memcpy.c" + .text + .globl memcpy # -- Begin function memcpy .p2align 2 .type memcpy,@function -memcpy: - andi a3, a1, 3 - seqz a3, a3 - seqz a4, a2 - or a3, a3, a4 - bnez a3, .LBBmemcpy0_11 - addi a5, a1, 1 - mv a6, a0 -.LBBmemcpy0_2: - lb a7, 0(a1) +memcpy: # @memcpy +# %bb.0: + andi a3, a1, 7 + beqz a3, .LBBmemcpy0_16 +# %bb.1: + beqz a2, .LBBmemcpy0_5 +# %bb.2: addi a4, a1, 1 - addi a3, a6, 1 - sb a7, 0(a6) - addi a2, a2, -1 - andi a1, a5, 3 - snez a1, a1 - snez a6, a2 - and a7, a1, a6 - addi a5, a5, 1 - mv a1, a4 - mv a6, a3 - bnez a7, .LBBmemcpy0_2 - andi a1, a3, 3 - beqz a1, .LBBmemcpy0_12 -.LBBmemcpy0_4: - li a5, 32 - bltu a2, a5, .LBBmemcpy0_26 - li a5, 3 - beq a1, a5, .LBBmemcpy0_19 - li a5, 2 - beq a1, a5, .LBBmemcpy0_22 li a5, 1 - bne a1, a5, .LBBmemcpy0_26 - lw a5, 0(a4) - sb a5, 0(a3) - srli a1, a5, 8 - sb a1, 1(a3) - srli a6, a5, 16 - addi a1, a3, 3 - sb a6, 2(a3) - addi a2, a2, -3 - addi a3, a4, 16 - li a4, 16 + mv a3, a0 +.LBBmemcpy0_3: # =>This Inner Loop Header: Depth=1 + lbu a7, 0(a1) + mv a6, a2 + addi a1, a1, 1 + andi t0, a4, 7 + sb a7, 0(a3) + addi a3, a3, 1 + addi a2, a2, -1 + beqz t0, .LBBmemcpy0_6 +# %bb.4: # in Loop: Header=BB0_3 Depth=1 + addi a4, a4, 1 + bne a6, a5, .LBBmemcpy0_3 + j .LBBmemcpy0_6 +.LBBmemcpy0_5: + mv a3, a0 +.LBBmemcpy0_6: + andi a4, a3, 7 + beqz a4, .LBBmemcpy0_17 +.LBBmemcpy0_7: + li a5, 64 + bgeu a2, a5, .LBBmemcpy0_12 +# %bb.8: + li a4, 32 + bgeu a2, a4, .LBBmemcpy0_44 .LBBmemcpy0_9: - lw a6, -12(a3) - srli a5, a5, 24 - slli a7, a6, 8 - lw t0, -8(a3) - or a5, a7, a5 - sw a5, 0(a1) - srli a5, a6, 24 - slli a6, t0, 8 - lw a7, -4(a3) - or a5, a6, a5 - sw a5, 4(a1) - srli a6, t0, 24 - slli t0, a7, 8 - lw a5, 0(a3) - or a6, t0, a6 - sw a6, 8(a1) 
- srli a6, a7, 24 - slli a7, a5, 8 - or a6, a7, a6 - sw a6, 12(a1) - addi a1, a1, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a4, a2, .LBBmemcpy0_9 - addi a4, a3, -13 - j .LBBmemcpy0_25 + andi a4, a2, 16 + bnez a4, .LBBmemcpy0_45 +.LBBmemcpy0_10: + andi a4, a2, 8 + bnez a4, .LBBmemcpy0_46 .LBBmemcpy0_11: - mv a3, a0 - mv a4, a1 - andi a1, a3, 3 - bnez a1, .LBBmemcpy0_4 + andi a4, a2, 4 + bnez a4, .LBBmemcpy0_47 + j .LBBmemcpy0_48 .LBBmemcpy0_12: - li a1, 16 - bltu a2, a1, .LBBmemcpy0_15 - li a1, 15 -.LBBmemcpy0_14: + addi a4, a4, -1 + slli a4, a4, 2 + lui a5, %hi(.LJTI0_0) + addi a5, a5, %lo(.LJTI0_0) + add a4, a4, a5 lw a5, 0(a4) - lw a6, 4(a4) - lw a7, 8(a4) - lw t0, 12(a4) - sw a5, 0(a3) - sw a6, 4(a3) - sw a7, 8(a3) - sw t0, 12(a3) - addi a4, a4, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a1, a2, .LBBmemcpy0_14 -.LBBmemcpy0_15: - andi a1, a2, 8 - beqz a1, .LBBmemcpy0_17 - lw a1, 0(a4) - lw a5, 4(a4) - sw a1, 0(a3) - sw a5, 4(a3) - addi a3, a3, 8 - addi a4, a4, 8 + ld a4, 0(a1) + jr a5 +.LBBmemcpy0_13: + srli a5, a4, 8 + srli a6, a4, 16 + srli a7, a4, 24 + srli t0, a4, 32 + srli t1, a4, 40 + sb a4, 0(a3) + sb a5, 1(a3) + sb a6, 2(a3) + sb a7, 3(a3) + srli a5, a4, 48 + addi a2, a2, -7 + sb t0, 4(a3) + sb t1, 5(a3) + sb a5, 6(a3) + addi a3, a3, 7 + addi a1, a1, 32 + li a5, 32 +.LBBmemcpy0_14: # =>This Inner Loop Header: Depth=1 + srli a6, a4, 56 + ld a7, -24(a1) + ld t0, -16(a1) + ld t1, -8(a1) + ld a4, 0(a1) + slli t2, a7, 8 + srli a7, a7, 56 + or a6, t2, a6 + slli t2, t0, 8 + srli t0, t0, 56 + or a7, t2, a7 + slli t2, t1, 8 + srli t1, t1, 56 + or t0, t2, t0 + slli t2, a4, 8 + or t1, t2, t1 + addi a2, a2, -32 + sd a6, 0(a3) + sd a7, 8(a3) + sd t0, 16(a3) + sd t1, 24(a3) + addi a3, a3, 32 + addi a1, a1, 32 + bltu a5, a2, .LBBmemcpy0_14 +# %bb.15: + addi a1, a1, -25 + li a4, 32 + bltu a2, a4, .LBBmemcpy0_9 + j .LBBmemcpy0_44 +.LBBmemcpy0_16: + mv a3, a0 + andi a4, a0, 7 + bnez a4, .LBBmemcpy0_7 .LBBmemcpy0_17: - andi a1, a2, 4 - beqz a1, .LBBmemcpy0_30 - lw 
a1, 0(a4) - sw a1, 0(a3) - addi a3, a3, 4 - addi a4, a4, 4 - j .LBBmemcpy0_30 -.LBBmemcpy0_19: - lw a5, 0(a4) - addi a1, a3, 1 - sb a5, 0(a3) - addi a2, a2, -1 - addi a3, a4, 16 - li a4, 18 + li a4, 32 + bltu a2, a4, .LBBmemcpy0_20 +# %bb.18: + li a4, 31 +.LBBmemcpy0_19: # =>This Inner Loop Header: Depth=1 + ld a5, 0(a1) + ld a6, 8(a1) + ld a7, 16(a1) + ld t0, 24(a1) + addi a1, a1, 32 + addi a2, a2, -32 + sd a5, 0(a3) + sd a6, 8(a3) + sd a7, 16(a3) + sd t0, 24(a3) + addi a3, a3, 32 + bltu a4, a2, .LBBmemcpy0_19 .LBBmemcpy0_20: - lw a6, -12(a3) - srli a5, a5, 8 - slli a7, a6, 24 - lw t0, -8(a3) - or a5, a7, a5 - sw a5, 0(a1) - srli a5, a6, 8 - slli a6, t0, 24 - lw a7, -4(a3) - or a5, a6, a5 - sw a5, 4(a1) - srli a6, t0, 8 - slli t0, a7, 24 - lw a5, 0(a3) - or a6, t0, a6 - sw a6, 8(a1) - srli a6, a7, 8 - slli a7, a5, 24 - or a6, a7, a6 - sw a6, 12(a1) - addi a1, a1, 16 - addi a2, a2, -16 - addi a3, a3, 16 - bltu a4, a2, .LBBmemcpy0_20 - addi a4, a3, -15 - j .LBBmemcpy0_25 + li a4, 16 + bgeu a2, a4, .LBBmemcpy0_23 +# %bb.21: + andi a4, a2, 8 + bnez a4, .LBBmemcpy0_24 .LBBmemcpy0_22: - lw a5, 0(a4) - sb a5, 0(a3) - srli a6, a5, 8 - addi a1, a3, 2 - sb a6, 1(a3) - addi a2, a2, -2 - addi a3, a4, 16 - li a4, 17 + andi a4, a2, 4 + bnez a4, .LBBmemcpy0_25 + j .LBBmemcpy0_48 .LBBmemcpy0_23: - lw a6, -12(a3) - srli a5, a5, 16 - slli a7, a6, 16 - lw t0, -8(a3) - or a5, a7, a5 - sw a5, 0(a1) - srli a5, a6, 16 - slli a6, t0, 16 - lw a7, -4(a3) - or a5, a6, a5 - sw a5, 4(a1) - srli a6, t0, 16 - slli t0, a7, 16 - lw a5, 0(a3) - or a6, t0, a6 - sw a6, 8(a1) - srli a6, a7, 16 - slli a7, a5, 16 - or a6, a7, a6 - sw a6, 12(a1) - addi a1, a1, 16 - addi a2, a2, -16 + ld a4, 0(a1) + ld a5, 8(a1) + sd a4, 0(a3) + sd a5, 8(a3) addi a3, a3, 16 - bltu a4, a2, .LBBmemcpy0_23 - addi a4, a3, -14 + addi a1, a1, 16 + andi a4, a2, 8 + beqz a4, .LBBmemcpy0_22 +.LBBmemcpy0_24: + ld a4, 0(a1) + addi a1, a1, 8 + sd a4, 0(a3) + addi a3, a3, 8 + andi a4, a2, 4 + beqz a4, .LBBmemcpy0_48 .LBBmemcpy0_25: - 
mv a3, a1 + lw a4, 0(a1) + addi a1, a1, 4 + sw a4, 0(a3) + addi a3, a3, 4 + j .LBBmemcpy0_48 .LBBmemcpy0_26: - andi a1, a2, 16 - bnez a1, .LBBmemcpy0_35 - andi a1, a2, 8 - bnez a1, .LBBmemcpy0_36 -.LBBmemcpy0_28: - andi a1, a2, 4 - beqz a1, .LBBmemcpy0_30 + srli a5, a4, 8 + srli a6, a4, 16 + addi a2, a2, -3 + sb a4, 0(a3) + sb a5, 1(a3) + sb a6, 2(a3) + addi a3, a3, 3 + addi a1, a1, 32 + li a5, 36 +.LBBmemcpy0_27: # =>This Inner Loop Header: Depth=1 + srli a6, a4, 24 + ld a7, -24(a1) + ld t0, -16(a1) + ld t1, -8(a1) + ld a4, 0(a1) + slli t2, a7, 40 + srli a7, a7, 24 + or a6, t2, a6 + slli t2, t0, 40 + srli t0, t0, 24 + or a7, t2, a7 + slli t2, t1, 40 + srli t1, t1, 24 + or t0, t2, t0 + slli t2, a4, 40 + or t1, t2, t1 + addi a2, a2, -32 + sd a6, 0(a3) + sd a7, 8(a3) + sd t0, 16(a3) + sd t1, 24(a3) + addi a3, a3, 32 + addi a1, a1, 32 + bltu a5, a2, .LBBmemcpy0_27 +# %bb.28: + addi a1, a1, -29 + li a4, 32 + bltu a2, a4, .LBBmemcpy0_9 + j .LBBmemcpy0_44 .LBBmemcpy0_29: - lb a1, 0(a4) - lb a5, 1(a4) - lb a6, 2(a4) - sb a1, 0(a3) + srli a5, a4, 8 + srli a6, a4, 16 + srli a7, a4, 24 + srli t0, a4, 32 + addi a2, a2, -5 + sb a4, 0(a3) sb a5, 1(a3) - lb a1, 3(a4) sb a6, 2(a3) - addi a4, a4, 4 - addi a5, a3, 4 - sb a1, 3(a3) - mv a3, a5 -.LBBmemcpy0_30: - andi a1, a2, 2 - bnez a1, .LBBmemcpy0_33 - andi a1, a2, 1 - bnez a1, .LBBmemcpy0_34 + sb a7, 3(a3) + sb t0, 4(a3) + addi a3, a3, 5 + addi a1, a1, 32 + li a5, 34 +.LBBmemcpy0_30: # =>This Inner Loop Header: Depth=1 + srli a6, a4, 40 + ld a7, -24(a1) + ld t0, -16(a1) + ld t1, -8(a1) + ld a4, 0(a1) + slli t2, a7, 24 + srli a7, a7, 40 + or a6, t2, a6 + slli t2, t0, 24 + srli t0, t0, 40 + or a7, t2, a7 + slli t2, t1, 24 + srli t1, t1, 40 + or t0, t2, t0 + slli t2, a4, 24 + or t1, t2, t1 + addi a2, a2, -32 + sd a6, 0(a3) + sd a7, 8(a3) + sd t0, 16(a3) + sd t1, 24(a3) + addi a3, a3, 32 + addi a1, a1, 32 + bltu a5, a2, .LBBmemcpy0_30 +# %bb.31: + addi a1, a1, -27 + li a4, 32 + bltu a2, a4, .LBBmemcpy0_9 + j .LBBmemcpy0_44 
.LBBmemcpy0_32: - ret -.LBBmemcpy0_33: - lb a1, 0(a4) - lb a5, 1(a4) - sb a1, 0(a3) - addi a4, a4, 2 - addi a1, a3, 2 + srli a5, a4, 8 + srli a6, a4, 16 + srli a7, a4, 24 + addi a2, a2, -4 + sb a4, 0(a3) sb a5, 1(a3) - mv a3, a1 - andi a1, a2, 1 - beqz a1, .LBBmemcpy0_32 -.LBBmemcpy0_34: - lb a1, 0(a4) - sb a1, 0(a3) - ret + sb a6, 2(a3) + sb a7, 3(a3) + addi a3, a3, 4 + addi a1, a1, 32 + li a5, 35 +.LBBmemcpy0_33: # =>This Inner Loop Header: Depth=1 + srli a6, a4, 32 + ld a7, -24(a1) + ld t0, -16(a1) + ld t1, -8(a1) + ld a4, 0(a1) + slli t2, a7, 32 + srli a7, a7, 32 + or a6, t2, a6 + slli t2, t0, 32 + srli t0, t0, 32 + or a7, t2, a7 + slli t2, t1, 32 + srli t1, t1, 32 + or t0, t2, t0 + slli t2, a4, 32 + or t1, t2, t1 + addi a2, a2, -32 + sd a6, 0(a3) + sd a7, 8(a3) + sd t0, 16(a3) + sd t1, 24(a3) + addi a3, a3, 32 + addi a1, a1, 32 + bltu a5, a2, .LBBmemcpy0_33 +# %bb.34: + addi a1, a1, -28 + li a4, 32 + bltu a2, a4, .LBBmemcpy0_9 + j .LBBmemcpy0_44 .LBBmemcpy0_35: - lb a1, 0(a4) - lb a5, 1(a4) - lb a6, 2(a4) - sb a1, 0(a3) + srli a5, a4, 8 + srli a6, a4, 16 + srli a7, a4, 24 + srli t0, a4, 32 + sb a4, 0(a3) sb a5, 1(a3) - lb a1, 3(a4) sb a6, 2(a3) - lb a5, 4(a4) - lb a6, 5(a4) - sb a1, 3(a3) - lb a1, 6(a4) - sb a5, 4(a3) - sb a6, 5(a3) - lb a5, 7(a4) - sb a1, 6(a3) - lb a1, 8(a4) - lb a6, 9(a4) - sb a5, 7(a3) - lb a5, 10(a4) - sb a1, 8(a3) - sb a6, 9(a3) - lb a1, 11(a4) - sb a5, 10(a3) - lb a5, 12(a4) - lb a6, 13(a4) - sb a1, 11(a3) - lb a1, 14(a4) - sb a5, 12(a3) - sb a6, 13(a3) - lb a5, 15(a4) - sb a1, 14(a3) - addi a4, a4, 16 - addi a1, a3, 16 - sb a5, 15(a3) - mv a3, a1 - andi a1, a2, 8 - beqz a1, .LBBmemcpy0_28 -.LBBmemcpy0_36: - lb a1, 0(a4) - lb a5, 1(a4) - lb a6, 2(a4) - sb a1, 0(a3) + sb a7, 3(a3) + srli a5, a4, 40 + addi a2, a2, -6 + sb t0, 4(a3) + sb a5, 5(a3) + addi a3, a3, 6 + addi a1, a1, 32 + li a5, 33 +.LBBmemcpy0_36: # =>This Inner Loop Header: Depth=1 + srli a6, a4, 48 + ld a7, -24(a1) + ld t0, -16(a1) + ld t1, -8(a1) + ld a4, 0(a1) + slli t2, 
a7, 16 + srli a7, a7, 48 + or a6, t2, a6 + slli t2, t0, 16 + srli t0, t0, 48 + or a7, t2, a7 + slli t2, t1, 16 + srli t1, t1, 48 + or t0, t2, t0 + slli t2, a4, 16 + or t1, t2, t1 + addi a2, a2, -32 + sd a6, 0(a3) + sd a7, 8(a3) + sd t0, 16(a3) + sd t1, 24(a3) + addi a3, a3, 32 + addi a1, a1, 32 + bltu a5, a2, .LBBmemcpy0_36 +# %bb.37: + addi a1, a1, -26 + li a4, 32 + bltu a2, a4, .LBBmemcpy0_9 + j .LBBmemcpy0_44 +.LBBmemcpy0_38: + srli a5, a4, 8 + addi a2, a2, -2 + sb a4, 0(a3) + sb a5, 1(a3) + addi a3, a3, 2 + addi a1, a1, 32 + li a5, 37 +.LBBmemcpy0_39: # =>This Inner Loop Header: Depth=1 + srli a6, a4, 16 + ld a7, -24(a1) + ld t0, -16(a1) + ld t1, -8(a1) + ld a4, 0(a1) + slli t2, a7, 48 + srli a7, a7, 16 + or a6, t2, a6 + slli t2, t0, 48 + srli t0, t0, 16 + or a7, t2, a7 + slli t2, t1, 48 + srli t1, t1, 16 + or t0, t2, t0 + slli t2, a4, 48 + or t1, t2, t1 + addi a2, a2, -32 + sd a6, 0(a3) + sd a7, 8(a3) + sd t0, 16(a3) + sd t1, 24(a3) + addi a3, a3, 32 + addi a1, a1, 32 + bltu a5, a2, .LBBmemcpy0_39 +# %bb.40: + addi a1, a1, -30 + li a4, 32 + bltu a2, a4, .LBBmemcpy0_9 + j .LBBmemcpy0_44 +.LBBmemcpy0_41: + sb a4, 0(a3) + addi a3, a3, 1 + addi a2, a2, -1 + addi a1, a1, 32 + li a5, 38 +.LBBmemcpy0_42: # =>This Inner Loop Header: Depth=1 + srli a6, a4, 8 + ld a7, -24(a1) + ld t0, -16(a1) + ld t1, -8(a1) + ld a4, 0(a1) + slli t2, a7, 56 + srli a7, a7, 8 + or a6, t2, a6 + slli t2, t0, 56 + srli t0, t0, 8 + or a7, t2, a7 + slli t2, t1, 56 + srli t1, t1, 8 + or t0, t2, t0 + slli t2, a4, 56 + or t1, t2, t1 + addi a2, a2, -32 + sd a6, 0(a3) + sd a7, 8(a3) + sd t0, 16(a3) + sd t1, 24(a3) + addi a3, a3, 32 + addi a1, a1, 32 + bltu a5, a2, .LBBmemcpy0_42 +# %bb.43: + addi a1, a1, -31 + li a4, 32 + bltu a2, a4, .LBBmemcpy0_9 +.LBBmemcpy0_44: + lbu a4, 0(a1) + lbu a5, 1(a1) + lbu a6, 2(a1) + lbu a7, 3(a1) + lbu t0, 4(a1) + lbu t1, 5(a1) + lbu t2, 6(a1) + lbu t3, 7(a1) + sb a4, 0(a3) sb a5, 1(a3) - lb a1, 3(a4) sb a6, 2(a3) - lb a5, 4(a4) - lb a6, 5(a4) - sb a1, 3(a3) - lb a1, 
6(a4) - sb a5, 4(a3) - sb a6, 5(a3) - lb a5, 7(a4) - sb a1, 6(a3) - addi a4, a4, 8 - addi a1, a3, 8 - sb a5, 7(a3) - mv a3, a1 - andi a1, a2, 4 - bnez a1, .LBBmemcpy0_29 - j .LBBmemcpy0_30 -.Lfuncmemcpy_end0: - .size memcpy, .Lfuncmemcpy_end0-memcpy - - .ident "Ubuntu clang version 14.0.6-++20220622053131+f28c006a5895-1~exp1~20220622173215.157" + sb a7, 3(a3) + lbu a4, 8(a1) + lbu a5, 9(a1) + lbu a6, 10(a1) + lbu a7, 11(a1) + sb t0, 4(a3) + sb t1, 5(a3) + sb t2, 6(a3) + sb t3, 7(a3) + lbu t0, 12(a1) + lbu t1, 13(a1) + lbu t2, 14(a1) + lbu t3, 15(a1) + sb a4, 8(a3) + sb a5, 9(a3) + sb a6, 10(a3) + sb a7, 11(a3) + lbu a4, 16(a1) + lbu a5, 17(a1) + lbu a6, 18(a1) + lbu a7, 19(a1) + sb t0, 12(a3) + sb t1, 13(a3) + sb t2, 14(a3) + sb t3, 15(a3) + lbu t0, 20(a1) + lbu t1, 21(a1) + lbu t2, 22(a1) + lbu t3, 23(a1) + sb a4, 16(a3) + sb a5, 17(a3) + sb a6, 18(a3) + sb a7, 19(a3) + lbu a4, 24(a1) + lbu a5, 25(a1) + lbu a6, 26(a1) + lbu a7, 27(a1) + sb t0, 20(a3) + sb t1, 21(a3) + sb t2, 22(a3) + sb t3, 23(a3) + lbu t0, 28(a1) + lbu t1, 29(a1) + lbu t2, 30(a1) + lbu t3, 31(a1) + addi a1, a1, 32 + sb a4, 24(a3) + sb a5, 25(a3) + sb a6, 26(a3) + sb a7, 27(a3) + addi a4, a3, 32 + sb t0, 28(a3) + sb t1, 29(a3) + sb t2, 30(a3) + sb t3, 31(a3) + mv a3, a4 + andi a4, a2, 16 + beqz a4, .LBBmemcpy0_10 +.LBBmemcpy0_45: + lbu a4, 0(a1) + lbu a5, 1(a1) + lbu a6, 2(a1) + lbu a7, 3(a1) + lbu t0, 4(a1) + lbu t1, 5(a1) + lbu t2, 6(a1) + lbu t3, 7(a1) + sb a4, 0(a3) + sb a5, 1(a3) + sb a6, 2(a3) + sb a7, 3(a3) + lbu a4, 8(a1) + lbu a5, 9(a1) + lbu a6, 10(a1) + lbu a7, 11(a1) + sb t0, 4(a3) + sb t1, 5(a3) + sb t2, 6(a3) + sb t3, 7(a3) + lbu t0, 12(a1) + lbu t1, 13(a1) + lbu t2, 14(a1) + lbu t3, 15(a1) + addi a1, a1, 16 + sb a4, 8(a3) + sb a5, 9(a3) + sb a6, 10(a3) + sb a7, 11(a3) + addi a4, a3, 16 + sb t0, 12(a3) + sb t1, 13(a3) + sb t2, 14(a3) + sb t3, 15(a3) + mv a3, a4 + andi a4, a2, 8 + beqz a4, .LBBmemcpy0_11 +.LBBmemcpy0_46: + lbu a4, 0(a1) + lbu a5, 1(a1) + lbu a6, 2(a1) + lbu a7, 3(a1) 
+ lbu t0, 4(a1) + lbu t1, 5(a1) + lbu t2, 6(a1) + lbu t3, 7(a1) + addi a1, a1, 8 + sb a4, 0(a3) + sb a5, 1(a3) + sb a6, 2(a3) + sb a7, 3(a3) + addi a4, a3, 8 + sb t0, 4(a3) + sb t1, 5(a3) + sb t2, 6(a3) + sb t3, 7(a3) + mv a3, a4 + andi a4, a2, 4 + beqz a4, .LBBmemcpy0_48 +.LBBmemcpy0_47: + lbu a4, 0(a1) + lbu a5, 1(a1) + lbu a6, 2(a1) + lbu a7, 3(a1) + addi a1, a1, 4 + addi t0, a3, 4 + sb a4, 0(a3) + sb a5, 1(a3) + sb a6, 2(a3) + sb a7, 3(a3) + mv a3, t0 +.LBBmemcpy0_48: + andi a4, a2, 2 + bnez a4, .LBBmemcpy0_51 +# %bb.49: + andi a2, a2, 1 + bnez a2, .LBBmemcpy0_52 +.LBBmemcpy0_50: + ret +.LBBmemcpy0_51: + lbu a4, 0(a1) + lbu a5, 1(a1) + addi a1, a1, 2 + addi a6, a3, 2 + sb a4, 0(a3) + sb a5, 1(a3) + mv a3, a6 + andi a2, a2, 1 + beqz a2, .LBBmemcpy0_50 +.LBBmemcpy0_52: + lbu a1, 0(a1) + sb a1, 0(a3) + ret +.Lfunc_end0: + .size memcpy, .Lfunc_end0-memcpy + .section .rodata,"a",@progbits + .p2align 2, 0x0 +.LJTI0_0: + .word .LBBmemcpy0_13 + .word .LBBmemcpy0_35 + .word .LBBmemcpy0_29 + .word .LBBmemcpy0_32 + .word .LBBmemcpy0_26 + .word .LBBmemcpy0_38 + .word .LBBmemcpy0_41 + # -- End function + .ident "Homebrew clang version 20.1.7" .section ".note.GNU-stack","",@progbits - .addrsig \ No newline at end of file + .addrsig diff --git a/sdk/sdk/src/memset.s b/sdk/sdk/src/memset.s index a19a11da..7415aea3 100644 --- a/sdk/sdk/src/memset.s +++ b/sdk/sdk/src/memset.s @@ -173,7 +173,7 @@ // All other files which have no copyright comments are original works // produced specifically for use as part of this library, written either // by Rich Felker, the main author of the library, or by one or more -// contibutors listed above. Details on authorship of individual files +// contributors listed above. Details on authorship of individual files // can be found in the git version control history of the project. The // omission of copyright and license comments in each file is in the // interest of source tree size. 
diff --git a/sdk/sdk/src/riscv_ecalls/bigint.rs b/sdk/sdk/src/riscv_ecalls/bigint.rs index b7d1e701..b67b2536 100644 --- a/sdk/sdk/src/riscv_ecalls/bigint.rs +++ b/sdk/sdk/src/riscv_ecalls/bigint.rs @@ -1,7 +1,7 @@ use super::syscall_uint256_mulmod; /// The number of limbs in a "uint256". -const N: usize = 8; +const N: usize = 4; /// Sets `result` to be `(x op y) % modulus`. /// @@ -15,34 +15,34 @@ const N: usize = 8; #[allow(unused_variables)] #[no_mangle] pub extern "C" fn sys_bigint( - result: *mut [u32; N], - op: u32, - x: *const [u32; N], - y: *const [u32; N], - modulus: *const [u32; N], + result: *mut [u64; N], + op: u64, + x: *const [u64; N], + y: *const [u64; N], + modulus: *const [u64; N], ) { // Instantiate a new uninitialized array of words to place the concatenated y and modulus. - let mut concat_y_modulus = core::mem::MaybeUninit::<[u32; N * 2]>::uninit(); + let mut concat_y_modulus = core::mem::MaybeUninit::<[u64; N * 2]>::uninit(); unsafe { - let result_ptr = result as *mut u32; - let x_ptr = x as *const u32; - let y_ptr = y as *const u32; - let concat_ptr = concat_y_modulus.as_mut_ptr() as *mut u32; + let result_ptr = result as *mut u64; + let x_ptr = x as *const u64; + let y_ptr = y as *const u64; + let concat_ptr = concat_y_modulus.as_mut_ptr() as *mut u64; // First copy the y value into the concatenated array. core::ptr::copy(y_ptr, concat_ptr, N); // Then, copy the modulus value into the concatenated array. Add the width of the y value // to the pointer to place the modulus value after the y value. - core::ptr::copy(modulus as *const u32, concat_ptr.add(N), N); + core::ptr::copy(modulus as *const u64, concat_ptr.add(N), N); // Copy x into the result array, as our syscall will write the result into the first input. - core::ptr::copy(x as *const u32, result_ptr, N); + core::ptr::copy(x as *const u64, result_ptr, N); // Call the uint256_mul syscall to multiply the x value with the concatenated y and modulus. 
// This syscall writes the result in-place, so it will mutate the result ptr appropriately. - let result_ptr = result_ptr as *mut [u32; N]; - let concat_ptr = concat_ptr as *mut [u32; N]; + let result_ptr = result_ptr as *mut [u64; N]; + let concat_ptr = concat_ptr as *mut [u64; N]; syscall_uint256_mulmod(result_ptr, concat_ptr); } } diff --git a/sdk/sdk/src/riscv_ecalls/bls12381.rs b/sdk/sdk/src/riscv_ecalls/bls12381.rs index 6e117536..765285ab 100644 --- a/sdk/sdk/src/riscv_ecalls/bls12381.rs +++ b/sdk/sdk/src/riscv_ecalls/bls12381.rs @@ -11,7 +11,7 @@ use core::arch::asm; /// byte boundary. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bls12381_add(p: *mut [u32; 24], q: *const [u32; 24]) { +pub extern "C" fn syscall_bls12381_add(p: *mut [u64; 12], q: *const [u64; 12]) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -36,7 +36,7 @@ pub extern "C" fn syscall_bls12381_add(p: *mut [u32; 24], q: *const [u32; 24]) { /// boundary. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bls12381_double(p: *mut [u32; 24]) { +pub extern "C" fn syscall_bls12381_double(p: *mut [u64; 12]) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -50,6 +50,11 @@ pub extern "C" fn syscall_bls12381_double(p: *mut [u32; 24]) { /// Decompresses a compressed BLS12-381 point. /// +/// The array represents two field elements. When considered as a byte array, the representation is +/// big-endian. This means that the `u64`s are actually byte-reversed due to the little-endian +/// architecture. The reason the type signature requires a u64 array is because we want the pointers +/// to be aligned to the architecture's register bit widths. +/// /// The first half of the input array should contain the X coordinate. The second half of the input /// array will be overwritten with the Y coordinate. /// @@ -59,9 +64,12 @@ pub extern "C" fn syscall_bls12381_double(p: *mut [u32; 24]) { /// boundary. 
#[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bls12381_decompress(point: &mut [u8; 96], sign_bit: bool) { +pub extern "C" fn syscall_bls12381_decompress(point: &mut [u64; 12], sign_bit: bool) { #[cfg(target_os = "zkvm")] { + // SAFETY: Both pointee types have the same size. The destination has a finer alignment than + // the source. + let point = unsafe { core::mem::transmute::<&mut [u64; 12], &mut [u8; 12 * 8]>(point) }; // Memory system/FpOps are little endian so we'll just flip the whole array before/after point.reverse(); let p = point.as_mut_ptr(); diff --git a/sdk/sdk/src/riscv_ecalls/bn254.rs b/sdk/sdk/src/riscv_ecalls/bn254.rs index fca4f6aa..91469996 100644 --- a/sdk/sdk/src/riscv_ecalls/bn254.rs +++ b/sdk/sdk/src/riscv_ecalls/bn254.rs @@ -11,7 +11,7 @@ use core::arch::asm; /// byte boundary. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bn254_add(p: *mut [u32; 16], q: *const [u32; 16]) { +pub extern "C" fn syscall_bn254_add(p: *mut [u64; 8], q: *const [u64; 8]) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -36,7 +36,7 @@ pub extern "C" fn syscall_bn254_add(p: *mut [u32; 16], q: *const [u32; 16]) { /// boundary. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bn254_double(p: *mut [u32; 16]) { +pub extern "C" fn syscall_bn254_double(p: *mut [u64; 8]) { #[cfg(target_os = "zkvm")] unsafe { asm!( diff --git a/sdk/sdk/src/riscv_ecalls/ed25519.rs b/sdk/sdk/src/riscv_ecalls/ed25519.rs index d9a45b42..50ffc37f 100644 --- a/sdk/sdk/src/riscv_ecalls/ed25519.rs +++ b/sdk/sdk/src/riscv_ecalls/ed25519.rs @@ -11,7 +11,7 @@ use core::arch::asm; /// byte boundary. 
#[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_ed_add(p: *mut [u32; 16], q: *const [u32; 16]) { +pub extern "C" fn syscall_ed_add(p: *mut [u64; 8], q: *const [u64; 8]) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -26,7 +26,7 @@ pub extern "C" fn syscall_ed_add(p: *mut [u32; 16], q: *const [u32; 16]) { unreachable!() } -/// Decompresses a compressed Edwards point. +/// Decompresses a compressed Edwards point, encoded as a little-endian u64 array. /// /// The second half of the input array should contain the compressed Y point with the final bit as /// the sign bit. The first half of the input array will be overwritten with the decompressed point, @@ -38,11 +38,14 @@ pub extern "C" fn syscall_ed_add(p: *mut [u32; 16], q: *const [u32; 16]) { /// boundary. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_ed_decompress(point: &mut [u8; 64]) { +pub extern "C" fn syscall_ed_decompress(point: &mut [u64; 8]) { #[cfg(target_os = "zkvm")] { - let sign = point[63] >> 7; - point[63] &= 0b0111_1111; + const SIGN_OFFSET: u32 = 1; + const SIGN_MASK: u64 = 1u64 << (u64::BITS - SIGN_OFFSET); + let sign = ((point[7] & SIGN_MASK) != 0) as u64; + point[7] <<= SIGN_OFFSET; + point[7] >>= SIGN_OFFSET; let p = point.as_mut_ptr() as *mut u8; unsafe { asm!( diff --git a/sdk/sdk/src/riscv_ecalls/fptower.rs b/sdk/sdk/src/riscv_ecalls/fptower.rs index aab9df6e..5f6f59d6 100644 --- a/sdk/sdk/src/riscv_ecalls/fptower.rs +++ b/sdk/sdk/src/riscv_ecalls/fptower.rs @@ -6,7 +6,7 @@ use core::arch::asm; /// The result is written over the first input. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bls12381_fp_addmod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bls12381_fp_addmod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -26,7 +26,7 @@ pub extern "C" fn syscall_bls12381_fp_addmod(x: *mut u32, y: *const u32) { /// The result is written over the first input. 
#[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bls12381_fp_submod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bls12381_fp_submod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -46,7 +46,7 @@ pub extern "C" fn syscall_bls12381_fp_submod(x: *mut u32, y: *const u32) { /// The result is written over the first input. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bls12381_fp_mulmod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bls12381_fp_mulmod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -66,7 +66,7 @@ pub extern "C" fn syscall_bls12381_fp_mulmod(x: *mut u32, y: *const u32) { /// The result is written over the first input. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bls12381_fp2_addmod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bls12381_fp2_addmod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -86,7 +86,7 @@ pub extern "C" fn syscall_bls12381_fp2_addmod(x: *mut u32, y: *const u32) { /// The result is written over the first input. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bls12381_fp2_submod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bls12381_fp2_submod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -106,7 +106,7 @@ pub extern "C" fn syscall_bls12381_fp2_submod(x: *mut u32, y: *const u32) { /// The result is written over the first input. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bls12381_fp2_mulmod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bls12381_fp2_mulmod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -126,7 +126,7 @@ pub extern "C" fn syscall_bls12381_fp2_mulmod(x: *mut u32, y: *const u32) { /// The result is written over the first input. 
#[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bn254_fp_addmod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bn254_fp_addmod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -146,7 +146,7 @@ pub extern "C" fn syscall_bn254_fp_addmod(x: *mut u32, y: *const u32) { /// The result is written over the first input. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bn254_fp_submod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bn254_fp_submod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -166,7 +166,7 @@ pub extern "C" fn syscall_bn254_fp_submod(x: *mut u32, y: *const u32) { /// The result is written over the first input. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bn254_fp_mulmod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bn254_fp_mulmod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -186,7 +186,7 @@ pub extern "C" fn syscall_bn254_fp_mulmod(x: *mut u32, y: *const u32) { /// The result is written over the first input. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bn254_fp2_addmod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bn254_fp2_addmod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -206,7 +206,7 @@ pub extern "C" fn syscall_bn254_fp2_addmod(x: *mut u32, y: *const u32) { /// The result is written over the first input. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bn254_fp2_submod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bn254_fp2_submod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -226,7 +226,7 @@ pub extern "C" fn syscall_bn254_fp2_submod(x: *mut u32, y: *const u32) { /// The result is written over the first input. 
#[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_bn254_fp2_mulmod(x: *mut u32, y: *const u32) { +pub extern "C" fn syscall_bn254_fp2_mulmod(x: *mut u64, y: *const u64) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -240,63 +240,3 @@ pub extern "C" fn syscall_bn254_fp2_mulmod(x: *mut u32, y: *const u32) { #[cfg(not(target_os = "zkvm"))] unreachable!() } - -/// Fp addition operation. -/// -/// The result is written over the first input. -#[allow(unused_variables)] -#[no_mangle] -pub extern "C" fn syscall_secp256k1_fp_addmod(x: *mut u32, y: *const u32) { - #[cfg(target_os = "zkvm")] - unsafe { - asm!( - "ecall", - in("t0") crate::riscv_ecalls::SECP256K1_FP_ADD, - in("a0") x, - in("a1") y, - ); - } - - #[cfg(not(target_os = "zkvm"))] - unreachable!() -} - -/// Fp subtraction operation. -/// -/// The result is written over the first input. -#[allow(unused_variables)] -#[no_mangle] -pub extern "C" fn syscall_secp256k1_fp_submod(x: *mut u32, y: *const u32) { - #[cfg(target_os = "zkvm")] - unsafe { - asm!( - "ecall", - in("t0") crate::riscv_ecalls::SECP256K1_FP_SUB, - in("a0") x, - in("a1") y, - ); - } - - #[cfg(not(target_os = "zkvm"))] - unreachable!() -} - -/// Fp multiplication operation. -/// -/// The result is written over the first input. -#[allow(unused_variables)] -#[no_mangle] -pub extern "C" fn syscall_secp256k1_fp_mulmod(x: *mut u32, y: *const u32) { - #[cfg(target_os = "zkvm")] - unsafe { - asm!( - "ecall", - in("t0") crate::riscv_ecalls::SECP256K1_FP_MUL, - in("a0") x, - in("a1") y, - ); - } - - #[cfg(not(target_os = "zkvm"))] - unreachable!() -} diff --git a/sdk/sdk/src/riscv_ecalls/secp256k1.rs b/sdk/sdk/src/riscv_ecalls/secp256k1.rs index 524952ae..d68aa194 100644 --- a/sdk/sdk/src/riscv_ecalls/secp256k1.rs +++ b/sdk/sdk/src/riscv_ecalls/secp256k1.rs @@ -12,7 +12,7 @@ use core::arch::asm; /// secp256k1 curve, and that `p` and `q` are not equal to each other. 
#[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_secp256k1_add(p: *mut [u32; 16], q: *mut [u32; 16]) { +pub extern "C" fn syscall_secp256k1_add(p: *mut [u64; 8], q: *mut [u64; 8]) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -37,7 +37,7 @@ pub extern "C" fn syscall_secp256k1_add(p: *mut [u32; 16], q: *mut [u32; 16]) { /// boundary. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_secp256k1_double(p: *mut [u32; 16]) { +pub extern "C" fn syscall_secp256k1_double(p: *mut [u64; 8]) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -54,9 +54,14 @@ pub extern "C" fn syscall_secp256k1_double(p: *mut [u32; 16]) { /// Decompresses a compressed Secp256k1 point. /// -/// The input array should be 64 bytes long, with the first 32 bytes containing the X coordinate in -/// big-endian format. The second half of the input will be overwritten with the Y coordinate of the -/// decompressed point in big-endian format using the point's parity (is_odd). +/// The array represents two field elements. When considered as a byte array, the representation is +/// big-endian. This means that the `u64`s are actually byte-reversed due to the little-endian +/// architecture. The reason the type signature requires a u64 array is because we want the pointers +/// to be aligned to the architecture's register bit widths. +/// +/// The input array should be 64 bytes long, with the first 32 bytes containing the X coordinate. +/// The second half of the input will be overwritten with the Y coordinate of the decompressed point +/// using the point's parity (is_odd). /// /// ### Safety /// @@ -64,9 +69,12 @@ pub extern "C" fn syscall_secp256k1_double(p: *mut [u32; 16]) { /// boundary. 
#[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_secp256k1_decompress(point: &mut [u8; 64], is_odd: bool) { +pub extern "C" fn syscall_secp256k1_decompress(point: &mut [u64; 8], is_odd: bool) { #[cfg(target_os = "zkvm")] { + // SAFETY: Both pointee types have the same size. The destination has a finer alignment than + // the source. + let point = unsafe { core::mem::transmute::<&mut [u64; 8], &mut [u8; 64]>(point) }; // Memory system/FpOps are little endian so we'll just flip the whole array before/after point.reverse(); let p = point.as_mut_ptr(); diff --git a/sdk/sdk/src/riscv_ecalls/secp256r1.rs b/sdk/sdk/src/riscv_ecalls/secp256r1.rs index 24aa5120..f9bcd2ee 100644 --- a/sdk/sdk/src/riscv_ecalls/secp256r1.rs +++ b/sdk/sdk/src/riscv_ecalls/secp256r1.rs @@ -1,7 +1,7 @@ #[cfg(target_os = "zkvm")] use core::arch::asm; -/// Adds two Secp256r1 points. +/// Adds two Secp256r1 points. /// /// The result is stored in the first point. /// @@ -9,10 +9,10 @@ use core::arch::asm; /// /// The caller must ensure that `p` and `q` are valid pointers to data that is aligned along a four /// byte boundary. Additionally, the caller must ensure that `p` and `q` are valid points on the -/// secp256r1 curve, and that `p` and `q` are not equal to each other. +/// secp256r1 curve, and that `p` and `q` are not equal to each other. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_secp256r1_add(p: *mut [u32; 16], q: *mut [u32; 16]) { +pub extern "C" fn syscall_secp256r1_add(p: *mut [u64; 8], q: *mut [u64; 8]) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -27,7 +27,7 @@ pub extern "C" fn syscall_secp256r1_add(p: *mut [u32; 16], q: *mut [u32; 16]) { unreachable!() } -/// Double a Secp256r1 point. +/// Double a Secp256r1 point. /// /// The result is stored in-place in the supplied buffer. /// @@ -37,7 +37,7 @@ pub extern "C" fn syscall_secp256r1_add(p: *mut [u32; 16], q: *mut [u32; 16]) { /// boundary. 
#[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_secp256r1_double(p: *mut [u32; 16]) { +pub extern "C" fn syscall_secp256r1_double(p: *mut [u64; 8]) { #[cfg(target_os = "zkvm")] unsafe { asm!( @@ -52,11 +52,16 @@ pub extern "C" fn syscall_secp256r1_double(p: *mut [u32; 16]) { unreachable!() } -/// Decompresses a compressed Secp256r1 point. +/// Decompresses a compressed Secp256r1 point. /// -/// The input array should be 64 bytes long, with the first 32 bytes containing the X coordinate in -/// big-endian format. The second half of the input will be overwritten with the Y coordinate of the -/// decompressed point in big-endian format using the point's parity (is_odd). +/// The array represents two field elements. When considered as a byte array, the representation is +/// big-endian. This means that the `u64`s are actually byte-reversed due to the little-endian +/// architecture. The reason the type signature requires a u64 array is because we want the pointers +/// to be aligned to the architecture's register bit widths. +/// +/// The input array should be 64 bytes long, with the first 32 bytes containing the X coordinate. +/// The second half of the input will be overwritten with the Y coordinate of the decompressed point +/// using the point's parity (is_odd). /// /// ### Safety /// @@ -64,9 +69,12 @@ pub extern "C" fn syscall_secp256r1_double(p: *mut [u32; 16]) { /// boundary. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_secp256r1_decompress(point: &mut [u8; 64], is_odd: bool) { +pub extern "C" fn syscall_secp256r1_decompress(point: &mut [u64; 8], is_odd: bool) { #[cfg(target_os = "zkvm")] { + // SAFETY: Both pointee types have the same size. The destination has a finer alignment than + // the source. 
+ let point = unsafe { core::mem::transmute::<&mut [u64; 8], &mut [u8; 64]>(point) }; // Memory system/FpOps are little endian so we'll just flip the whole array before/after point.reverse(); let p = point.as_mut_ptr(); diff --git a/sdk/sdk/src/riscv_ecalls/sha_compress.rs b/sdk/sdk/src/riscv_ecalls/sha_compress.rs index 8a64ef41..43ae4607 100644 --- a/sdk/sdk/src/riscv_ecalls/sha_compress.rs +++ b/sdk/sdk/src/riscv_ecalls/sha_compress.rs @@ -9,7 +9,7 @@ use core::arch::asm; /// four byte boundary. #[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_sha256_compress(w: *mut [u32; 64], state: *mut [u32; 8]) { +pub extern "C" fn syscall_sha256_compress(w: *mut [u64; 64], state: *mut [u64; 8]) { #[cfg(target_os = "zkvm")] unsafe { asm!( diff --git a/sdk/sdk/src/riscv_ecalls/sha_extend.rs b/sdk/sdk/src/riscv_ecalls/sha_extend.rs index ab68cd47..96dee84b 100644 --- a/sdk/sdk/src/riscv_ecalls/sha_extend.rs +++ b/sdk/sdk/src/riscv_ecalls/sha_extend.rs @@ -9,7 +9,7 @@ use core::arch::asm; /// boundary. 
#[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_sha256_extend(w: *mut [u32; 64]) { +pub extern "C" fn syscall_sha256_extend(w: *mut [u64; 64]) { #[cfg(target_os = "zkvm")] unsafe { asm!( diff --git a/sdk/sdk/src/riscv_ecalls/sys.rs b/sdk/sdk/src/riscv_ecalls/sys.rs index 1d6f75ce..0a9e134a 100644 --- a/sdk/sdk/src/riscv_ecalls/sys.rs +++ b/sdk/sdk/src/riscv_ecalls/sys.rs @@ -26,7 +26,7 @@ static SYS_RAND_WARNING: std::sync::Once = std::sync::Once::new(); #[no_mangle] pub unsafe extern "C" fn sys_rand(recv_buf: *mut u8, words: usize) { SYS_RAND_WARNING.call_once(|| { - println!("WARNING: Using insecure random number generator."); + eprintln!("WARNING: Using insecure random number generator."); }); let mut rng = RNG.lock().unwrap(); for i in 0..words { diff --git a/sdk/sdk/src/riscv_ecalls/uint256_mul.rs b/sdk/sdk/src/riscv_ecalls/uint256_mul.rs index 9e7d1af8..e6ddcbb6 100644 --- a/sdk/sdk/src/riscv_ecalls/uint256_mul.rs +++ b/sdk/sdk/src/riscv_ecalls/uint256_mul.rs @@ -11,7 +11,7 @@ use core::arch::asm; /// byte boundary. 
#[allow(unused_variables)] #[no_mangle] -pub extern "C" fn syscall_uint256_mulmod(x: *mut [u32; 8], y: *const [u32; 8]) { +pub extern "C" fn syscall_uint256_mulmod(x: *mut [u64; 4], y: *const [u64; 4]) { #[cfg(target_os = "zkvm")] unsafe { asm!( diff --git a/sdk/sdk/src/riscv_ecalls/unconstrained.rs b/sdk/sdk/src/riscv_ecalls/unconstrained.rs index e549ce7b..6659c391 100644 --- a/sdk/sdk/src/riscv_ecalls/unconstrained.rs +++ b/sdk/sdk/src/riscv_ecalls/unconstrained.rs @@ -16,7 +16,7 @@ pub fn syscall_enter_unconstrained() -> bool { #[cfg(not(target_os = "zkvm"))] { - println!("Entering unconstrained execution block"); + eprintln!("Entering unconstrained execution block"); continue_unconstrained = 1; } @@ -35,5 +35,5 @@ pub fn syscall_exit_unconstrained() { } #[cfg(not(target_os = "zkvm"))] - println!("Exiting unconstrained execution block"); + eprintln!("Exiting unconstrained execution block"); } diff --git a/vm/src/emulator/riscv/hook/bls.rs b/vm/src/emulator/riscv/hook/bls.rs new file mode 100644 index 00000000..5e4ba41c --- /dev/null +++ b/vm/src/emulator/riscv/hook/bls.rs @@ -0,0 +1,61 @@ +use super::super::emulator::RiscvEmulator; +use crate::chips::gadgets::{ + curves::weierstrass::bls381::Bls381BaseField, utils::field_params::FieldParameters, +}; +use num_bigint::BigUint; +use num_traits::Zero; + +const NQR_BLS12_381: [u8; 48] = { + let mut nqr = [0; 48]; + nqr[47] = 2; + nqr +}; + +fn pad_to_be(val: &BigUint, len: usize) -> Vec { + let mut bytes = val.to_bytes_le(); + bytes.resize(len, 0); + bytes.reverse(); + bytes +} + +pub fn hook_bls12_381_sqrt(_: &RiscvEmulator, buf: &[u8]) -> Vec> { + assert!(buf.len() == 48, "BLS12-381 sqrt input must be 48 bytes"); + + let field_element = BigUint::from_bytes_be(buf); + + if field_element.is_zero() { + return vec![vec![1], vec![0; 48]]; + } + + let modulus = BigUint::from_bytes_le(Bls381BaseField::MODULUS); + let exp = (&modulus + BigUint::from(1u64)) / BigUint::from(4u64); + let sqrt = field_element.modpow(&exp, 
&modulus); + + let square = (&sqrt * &sqrt) % &modulus; + if square != field_element { + let nqr = BigUint::from_bytes_be(&NQR_BLS12_381); + let qr = (&nqr * &field_element) % &modulus; + let root = qr.modpow(&exp, &modulus); + + assert!((&root * &root) % &modulus == qr, "NQR sanity check failed"); + + return vec![vec![0], pad_to_be(&root, 48)]; + } + + vec![vec![1], pad_to_be(&sqrt, 48)] +} + +pub fn hook_bls12_381_inverse(_: &RiscvEmulator, buf: &[u8]) -> Vec> { + assert!(buf.len() == 48, "BLS12-381 inverse input must be 48 bytes"); + + let field_element = BigUint::from_bytes_be(buf); + assert!( + !field_element.is_zero(), + "Field element is the additive identity" + ); + + let modulus = BigUint::from_bytes_le(Bls381BaseField::MODULUS); + let inverse = field_element.modpow(&(&modulus - BigUint::from(2u64)), &modulus); + + vec![pad_to_be(&inverse, 48)] +} diff --git a/vm/src/emulator/riscv/hook/fp.rs b/vm/src/emulator/riscv/hook/fp.rs new file mode 100644 index 00000000..42bef896 --- /dev/null +++ b/vm/src/emulator/riscv/hook/fp.rs @@ -0,0 +1,121 @@ +use super::super::emulator::RiscvEmulator; +use num_bigint::BigUint; +use num_traits::{One, Zero}; + +fn pad_to_be(val: &BigUint, len: usize) -> Vec { + let mut bytes = val.to_bytes_le(); + bytes.resize(len, 0); + bytes.reverse(); + bytes +} + +pub fn hook_fp_inverse(_: &RiscvEmulator, buf: &[u8]) -> Vec> { + let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize; + + assert!(buf.len() == 4 + 2 * len, "FpOp Hook: Invalid buffer length"); + + let buf = &buf[4..]; + let element = BigUint::from_bytes_be(&buf[..len]); + let modulus = BigUint::from_bytes_be(&buf[len..2 * len]); + + assert!(!element.is_zero(), "FpOp: Inverse called with zero"); + + let inverse = element.modpow(&(&modulus - BigUint::from(2u64)), &modulus); + + vec![pad_to_be(&inverse, len)] +} + +pub fn hook_fp_sqrt(_: &RiscvEmulator, buf: &[u8]) -> Vec> { + let len: usize = u32::from_be_bytes(buf[0..4].try_into().unwrap()) as usize; + + 
assert!(buf.len() == 4 + 3 * len, "FpOp Hook: Invalid buffer length"); + + let buf = &buf[4..]; + let element = BigUint::from_bytes_be(&buf[..len]); + let modulus = BigUint::from_bytes_be(&buf[len..2 * len]); + let nqr = BigUint::from_bytes_be(&buf[2 * len..3 * len]); + + assert!( + element < modulus, + "Element is not less than modulus, the hook only accepts canonical representations" + ); + assert!( + !nqr.is_zero() && nqr < modulus, + "NQR is zero or non-canonical, the hook only accepts canonical representations" + ); + + if element.is_zero() { + return vec![vec![1], vec![0; len]]; + } + + if let Some(root) = sqrt_fp(&element, &modulus, &nqr) { + vec![vec![1], pad_to_be(&root, len)] + } else { + let qr = (&nqr * &element) % &modulus; + let root = sqrt_fp(&qr, &modulus, &nqr).unwrap(); + + vec![vec![0], pad_to_be(&root, len)] + } +} + +fn sqrt_fp(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option { + if modulus % BigUint::from(4u64) == BigUint::from(3u64) { + let maybe_root = element.modpow( + &((modulus + BigUint::from(1u64)) / BigUint::from(4u64)), + modulus, + ); + + return Some(maybe_root).filter(|root| root * root % modulus == *element); + } + + tonelli_shanks(element, modulus, nqr) +} + +#[allow(clippy::many_single_char_names)] +fn tonelli_shanks(element: &BigUint, modulus: &BigUint, nqr: &BigUint) -> Option { + if legendre_symbol(element, modulus) != BigUint::one() { + return None; + } + + let mut s = BigUint::zero(); + let mut q = modulus - BigUint::one(); + while &q % &BigUint::from(2u64) == BigUint::zero() { + s += BigUint::from(1u64); + q /= BigUint::from(2u64); + } + + let z = nqr; + let mut c = z.modpow(&q, modulus); + let mut r = element.modpow(&((&q + BigUint::from(1u64)) / BigUint::from(2u64)), modulus); + let mut t = element.modpow(&q, modulus); + let mut m = s; + + while t != BigUint::one() { + let mut i = BigUint::zero(); + let mut tt = t.clone(); + while tt != BigUint::one() { + tt = &tt * &tt % modulus; + i += BigUint::from(1u64); + + if i == m { + 
return None; + } + } + + let b_pow = BigUint::from(2u64).pow((&m - &i - BigUint::from(1u64)).try_into().unwrap()); + let b = c.modpow(&b_pow, modulus); + + r = &r * &b % modulus; + c = &b * &b % modulus; + t = &t * &c % modulus; + m = i; + } + + Some(r) +} + +fn legendre_symbol(element: &BigUint, modulus: &BigUint) -> BigUint { + assert!(!element.is_zero(), "FpOp: Legendre symbol of zero called."); + + element.modpow(&((modulus - BigUint::one()) / BigUint::from(2u64)), modulus) +} diff --git a/vm/src/emulator/riscv/hook/mod.rs b/vm/src/emulator/riscv/hook/mod.rs index 3234055f..a8599a7e 100644 --- a/vm/src/emulator/riscv/hook/mod.rs +++ b/vm/src/emulator/riscv/hook/mod.rs @@ -1,5 +1,7 @@ +mod bls; mod ecrecover; mod ed_decompress; +mod fp; use super::riscv_emulator::RiscvEmulator; use hashbrown::HashMap; @@ -10,10 +12,26 @@ const SECP256K1_ECRECOVER: u32 = 5; /// The file descriptor through which to access `hook_ed_decompress`. pub const FD_EDDECOMPRESS: u32 = 8; +/// The file descriptor through which to access `hook_fp_sqrt`. +pub const FD_FP_SQRT: u32 = 10; + +/// The file descriptor through which to access `hook_fp_inverse`. +pub const FD_FP_INV: u32 = 11; + +/// The file descriptor through which to access `hook_bls12_381_sqrt`. +pub const FD_BLS12_381_SQRT: u32 = 12; + +/// The file descriptor through which to access `hook_bls12_381_inverse`. +pub const FD_BLS12_381_INVERSE: u32 = 13; + pub fn default_hook_map() -> HashMap { let hooks: [(u32, Hook); _] = [ (SECP256K1_ECRECOVER, ecrecover::ecrecover), (FD_EDDECOMPRESS, ed_decompress::ed_decompress), + (FD_FP_SQRT, fp::hook_fp_sqrt), + (FD_FP_INV, fp::hook_fp_inverse), + (FD_BLS12_381_SQRT, bls::hook_bls12_381_sqrt), + (FD_BLS12_381_INVERSE, bls::hook_bls12_381_inverse), ]; HashMap::from_iter(hooks) }