Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions ceno_zkvm/src/scheme/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ pub(crate) fn infer_tower_product_witness<E: ExtensionField>(
#[cfg(test)]
mod tests {

use ff_ext::{FieldInto, GoldilocksExt2};
use ff_ext::{BabyBearExt4, FieldInto, GoldilocksExt2};
use itertools::Itertools;
use multilinear_extensions::{
commutative_op_mle_pair,
Expand All @@ -302,7 +302,7 @@ mod tests {

#[test]
fn test_infer_tower_witness() {
type E = GoldilocksExt2;
type E = BabyBearExt4;
let num_product_fanin = 2;
let last_layer: Vec<MultilinearExtension<E>> = vec![
vec![E::ONE, E::from_canonical_u64(2u64)].into_mle(),
Expand Down Expand Up @@ -454,7 +454,7 @@ mod tests {

#[test]
fn test_infer_tower_logup_witness() {
type E = GoldilocksExt2;
type E = BabyBearExt4;
let num_vars = 2;
let q: Vec<MultilinearExtension<E>> = vec![
vec![1, 2, 3, 4]
Expand Down
1 change: 1 addition & 0 deletions ff_ext/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ once_cell = "1.21.3"
p3.workspace = true
rand_core.workspace = true
serde.workspace = true
rand.workspace = true

[features]
nightly-features = ["p3/nightly-features"]
42 changes: 42 additions & 0 deletions ff_ext/src/babybear.rs
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,46 @@ pub mod impl_babybear {
.collect()
}
}

#[cfg(test)]
mod tests {
use p3::{
babybear::BabyBear,
field::{FieldAlgebra, FieldExtensionAlgebra},
};
use rand::thread_rng;

use crate::{BabyBearExt4, FromUniformBytes};

#[test]
fn test_ext_mul() {
for (a_limbs, b_limbs, c_limbs) in vec![
(vec![0, 1, 0, 0], vec![0, 0, 0, 1], vec![11, 0, 0, 0]), // x*x^3 = 11
(vec![0, 0, 1, 0], vec![0, 0, 0, 1], vec![0, 11, 0, 0]), // x^2*x^3 = 11*x
(vec![1, 2, 0, 0], vec![0, 0, 3, 4], vec![88, 0, 3, 10]), /* (1+2x)*(3x^2+4x^3) = 88 + 3x^2+10x^3 */
(vec![1, 2, 3, 4], vec![5, 6, 7, 8], vec![676, 588, 386, 60]),
] {
let a = BabyBearExt4::from_base_iter(
a_limbs.into_iter().map(BabyBear::from_canonical_u32),
);
let b = BabyBearExt4::from_base_iter(
b_limbs.into_iter().map(BabyBear::from_canonical_u32),
);
let c = a * b;
assert_eq!(
c,
BabyBearExt4::from_base_iter(
c_limbs.into_iter().map(BabyBear::from_canonical_u32)
)
);
}

// print one random example
let mut rng = thread_rng();
let a = BabyBearExt4::random(&mut rng);
let b = BabyBearExt4::random(&mut rng);
let c = a * b;
println!("a: {:?}, b: {:?}, c: {:?}", a, b, c)
}
}
}
69 changes: 67 additions & 2 deletions mpcs/src/basefold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,11 +482,20 @@ where

#[cfg(test)]
mod test {
use ff_ext::GoldilocksExt2;
use crate::util::codeword_fold_with_challenge;
use ff_ext::{BabyBearExt4, ExtensionField, FromUniformBytes, GoldilocksExt2, PoseidonField};
use itertools::Itertools;
use p3::{
babybear::BabyBear,
commit::Mmcs,
field::{Field, FieldAlgebra, TwoAdicField},
matrix::{Matrix, bitrev::BitReversableMatrix, dense::RowMajorMatrix},
};

use crate::{
basefold::Basefold,
basefold::{Basefold, poseidon2_merkle_tree},
test_util::{run_batch_commit_open_verify, run_commit_open_verify},
util::test::rand_vec,
};

use super::BasefoldRSParams;
Expand All @@ -509,4 +518,60 @@ mod test {
run_batch_commit_open_verify::<GoldilocksExt2, PcsGoldilocksRSCode>(20, 21, 64);
}
}

#[test]
fn test_fri_fold() {
type E = BabyBearExt4;
type F = BabyBear;

let mut rng = rand::thread_rng();
// fold a codeword of length 2^16 using random challenge
let codeword_log2_size = 16;
let codeword = E::random_vec(1 << codeword_log2_size, &mut rng);

let twiddle = F::GENERATOR.exp_power_of_2(F::TWO_ADICITY - codeword_log2_size);
let inv_2 = F::from_canonical_u64(2).inverse();

let challenge = E::random(&mut rng);

codeword
.chunks(2)
.zip(twiddle.powers())
.for_each(|(chunk, coeff)| {
codeword_fold_with_challenge(chunk, challenge, coeff, inv_2);
})
}

#[test]
fn test_bit_reverse() {
let v = (0..8).collect_vec();

let m = RowMajorMatrix::new(v, 1);
assert_eq!(
m.bit_reverse_rows().to_row_major_matrix().values,
vec![0b000, 0b100, 0b010, 0b110, 0b001, 0b101, 0b011, 0b111]
);
}

#[test]
fn test_poseidon2_mmcs() {
type E = BabyBearExt4;

let mut rng = rand::thread_rng();
let base_mmcs: <<E as ExtensionField>::BaseField as PoseidonField>::MMCS =
poseidon2_merkle_tree::<E>();

// commit to two matrices whose layouts are (2^10, 4) and (2^14, 10)
let matrices = vec![(1 << 10, 4), (1 << 14, 10)]
.into_iter()
.map(|(rows, cols)| {
RowMajorMatrix::<<E as ExtensionField>::BaseField>::new(
rand_vec(rows * cols, &mut rng),
cols,
)
})
.collect_vec();

base_mmcs.commit(matrices);
}
}
31 changes: 31 additions & 0 deletions mpcs/src/basefold/commit_phase.rs
Original file line number Diff line number Diff line change
Expand Up @@ -479,3 +479,34 @@ where

Some(next_challenge)
}

#[cfg(test)]
mod tests {
use ff_ext::{BabyBearExt4, FromUniformBytes};
use itertools::Itertools;
use p3::{
babybear::BabyBear,
matrix::{Matrix, dense::RowMajorMatrix},
};
use rand::thread_rng;

type E = BabyBearExt4;
type F = BabyBear;

#[test]
fn test_matrix_multiply_vector() {
let num_rows = 1 << 10;
let num_cols = 32;

let mut rng = thread_rng();
let matrix = RowMajorMatrix::new(F::random_vec(num_rows * num_cols, &mut rng), num_cols);
let v = E::random_vec(num_cols, &mut rng);

// matrix multiply vector
// codeword[i] = sum_j matrix[i][j] * v[j]
let _codeword = matrix
.rows()
.map(|row| v.iter().zip(row).map(|(a, b)| *a * b).sum::<E>())
.collect_vec();
}
}
31 changes: 28 additions & 3 deletions mpcs/src/basefold/encoding/rs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,18 +292,19 @@ where
mod tests {
use std::collections::VecDeque;

use ff_ext::GoldilocksExt2;
use ff_ext::{BabyBearExt4, GoldilocksExt2};
use itertools::izip;
use p3::{
commit::{ExtensionMmcs, Mmcs},
goldilocks::Goldilocks,
};

use rand::rngs::OsRng;
use rand::{rngs::OsRng, thread_rng};
use transcript::BasicTranscript;

use crate::{
basefold::commit_phase::basefold_fri_round, util::merkle_tree::poseidon2_merkle_tree,
basefold::commit_phase::basefold_fri_round,
util::{merkle_tree::poseidon2_merkle_tree, test::rand_vec},
};

use super::*;
Expand Down Expand Up @@ -369,4 +370,28 @@ mod tests {
&codeword_from_folded_rmm.values
);
}

#[test]
fn test_lde_batch() {
type E = BabyBearExt4;
let mut rng = thread_rng();
let dft: Radix2DitParallel<<E as ExtensionField>::BaseField> = Default::default();

let width = 10;
let added_bits = vec![1, 2, 3];
for log2_n in 1..22 {
let matrix = DenseMatrix::new(rand_vec(width * (1 << log2_n), &mut rng), width);
for added_bit in added_bits.iter() {
let dur = std::time::Instant::now();
dft.lde_batch(matrix.clone(), *added_bit);
println!(
"lde(matrix {}x{}, {}) took {:?}",
width,
1 << log2_n,
added_bit,
dur.elapsed()
);
}
}
}
}
36 changes: 36 additions & 0 deletions multilinear_extensions/src/mle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1253,3 +1253,39 @@ macro_rules! commutative_op_mle_pair {
commutative_op_mle_pair!(|$a, $b| $op, |out| out)
};
}

#[cfg(test)]
mod tests {
use crate::mle::{IntoMLE, MultilinearExtension};
use ff_ext::{BabyBearExt4, FromUniformBytes};
use itertools::Itertools;
use rand::thread_rng;

type E = BabyBearExt4;

#[test]
fn test_fix_var() {
let mut rng = thread_rng();
let mle: MultilinearExtension<'_, E> = MultilinearExtension::random(3, &mut rng);
let mle_clone = mle.clone();
let point = E::random(&mut rng);

let m1 = mle.fix_variables(&[point]);
let m2 = mle_clone
.as_view()
.get_base_field_vec()
.chunks(2)
.map(|chunk| {
// eq(1,r)*f(1) + eq(0,r)*f(0)
// r*f(1) + (1-r)*f(0)
let a = chunk[0];
let b = chunk[1];
point * (b - a) + a
})
.collect_vec()
.into_mle();

assert_eq!(m1.num_vars(), m2.num_vars());
assert_eq!(m1, m2,);
}
}