Skip to content

Commit f8e6640

Browse files
committed
Buggy sumcheck?
1 parent 8631b15 commit f8e6640

File tree

5 files changed

+72
-17
lines changed

5 files changed

+72
-17
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

mpcs/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ rand.workspace = true
3434
rand_chacha.workspace = true
3535
rayon = { workspace = true, optional = true }
3636
serde.workspace = true
37+
sumcheck = { path = "../sumcheck" }
3738
transcript = { path = "../transcript" }
3839
whir = { path = "../whir", features = ["ceno"] }
3940
zeroize = "1.8"

mpcs/src/lib.rs

Lines changed: 66 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
#![deny(clippy::cargo)]
22
use ff_ext::ExtensionField;
3-
use itertools::Itertools;
4-
use multilinear_extensions::mle::{DenseMultilinearExtension, FieldType, MultilinearExtension};
3+
use itertools::{Either, Itertools};
4+
use multilinear_extensions::{mle::{DenseMultilinearExtension, FieldType, MultilinearExtension}, virtual_poly::{build_eq_x_r, eq_eval, VPAuxInfo}};
55
use serde::{Serialize, de::DeserializeOwned};
66
use std::fmt::Debug;
77
use transcript::{BasicTranscript, Transcript};
88
use util::hash::Digest;
99
use p3_field::PrimeCharacteristicRing;
10+
use multilinear_extensions::virtual_poly::VirtualPolynomial;
11+
use sumcheck::structs::{IOPProof, IOPProverState, IOPVerifierState};
1012

1113
pub mod sum_check;
1214
pub mod util;
@@ -334,21 +336,50 @@ pub fn pcs_batch_open_diff_size<E: ExtensionField, Pcs: PolynomialCommitmentSche
334336
packed_comm: &Pcs::CommitmentWithWitness,
335337
final_comm: &Option<Pcs::CommitmentWithWitness>,
336338
points: &[Vec<E>],
337-
evals: &[E],
339+
_evals: &[E],
338340
transcript: &mut impl Transcript<E>,
339-
) -> Result<(Pcs::Proof, Option<Pcs::Proof>), Error> {
341+
) -> Result<(IOPProof<E>, Vec<E>, Pcs::Proof, Option<Pcs::Proof>), Error> {
342+
assert_eq!(polys.len(), points.len());
340343
// TODO: Sort the polys by decreasing size
344+
let arc_polys: Vec<ArcMultilinearExtension<E>> = polys.into_iter().map(|p| ArcMultilinearExtension::from(p.clone())).collect();
345+
// UNIFY SUMCHECK
346+
// Sample random coefficients for each poly
347+
let unify_coeffs = transcript.sample_vec(polys.len());
348+
// First convert each point into EQ
349+
let eq_points = points.iter().map(|p| build_eq_x_r(p)).collect::<Vec<_>>();
350+
351+
let mut sumcheck_poly = VirtualPolynomial::<E>::new(polys[0].num_vars());
352+
for ((eq, poly), coeff) in eq_points.into_iter().zip(arc_polys).zip(unify_coeffs) {
353+
let claim = match (&poly.evaluations(), &eq.evaluations) {
354+
(FieldType::Base(p), FieldType::Ext(e)) => {
355+
p.iter().zip(e).map(|(p, e)| E::from_bases(&[*p, E::BaseField::ZERO]) * *e).fold(E::ZERO, |s, i| s + i)
356+
}
357+
_ => unreachable!()
358+
};
359+
println!("C: {:?}", claim);
360+
sumcheck_poly.add_mle_list(vec![eq, poly], coeff);
361+
}
362+
let (unify_proof, unify_prover_state) = IOPProverState::prove_batch_polys(1, vec![sumcheck_poly], transcript);
363+
let packed_point = unify_proof.point.clone();
364+
// sumcheck_poly is consisted of [eq, poly, eq, poly, ...], we only need the evaluations to `poly` here
365+
let sumcheck_evals = unify_prover_state.get_mle_final_evaluations();
366+
let (_, evals): (Vec<_>, Vec<_>) = sumcheck_evals.into_iter().enumerate().partition_map(|(i, e)| {
367+
if i % 2 == 0 {
368+
Either::Left(e)
369+
} else {
370+
Either::Right(e)
371+
}
372+
});
373+
374+
// GEN & EVAL PACK POLYS
341375
// TODO: The prover should be able to avoid packing the polys again
342376
let (packed_polys, final_poly, packed_comps, final_comp) = pack_poly_prover(polys);
343-
// TODO: Add unifying sumcheck if the points do not match
344-
// For now, assume that all polys are evaluated on the same points
345-
let packed_point = points[0].clone();
377+
let packed_polys: Vec<ArcMultilinearExtension<E>> = packed_polys.into_iter().map(|p| ArcMultilinearExtension::from(p)).collect();
346378
// Note: the points are stored in reverse
347379
let final_point = if let Some(final_poly) = &final_poly { packed_point[..final_poly.num_vars].to_vec() } else { Vec::new() };
348380
// Use comps to compute evals for packed polys from regular evals
349-
let (packed_evals, final_eval) = compute_packed_eval(&packed_point, &final_point, evals, &packed_comps, &final_comp);
381+
let (packed_evals, final_eval) = compute_packed_eval(&packed_point, &final_point, &evals, &packed_comps, &final_comp);
350382

351-
let packed_polys: Vec<ArcMultilinearExtension<E>> = packed_polys.into_iter().map(|p| ArcMultilinearExtension::from(p)).collect();
352383
let pack_proof = Pcs::simple_batch_open(pp, &packed_polys, packed_comm, &packed_point, &packed_evals, transcript)?;
353384
let final_proof = match (&final_poly, &final_comm, &final_eval) {
354385
(Some(final_poly), Some(final_comm), Some(final_eval)) => {
@@ -357,7 +388,7 @@ pub fn pcs_batch_open_diff_size<E: ExtensionField, Pcs: PolynomialCommitmentSche
357388
(None, None, None) => None,
358389
_ => unreachable!(),
359390
};
360-
Ok((pack_proof, final_proof))
391+
Ok((unify_proof, evals, pack_proof, final_proof))
361392
}
362393

363394
pub fn pcs_verify<E: ExtensionField, Pcs: PolynomialCommitmentScheme<E>>(
@@ -391,22 +422,39 @@ pub fn pcs_batch_verify_diff_size<'a, E: ExtensionField, Pcs: PolynomialCommitme
391422
packed_comm: &Pcs::Commitment,
392423
final_comm: &Option<Pcs::Commitment>,
393424
points: &[Vec<E>],
394-
evals: &[E],
425+
poly_evals: &[E], // Evaluation of polys on original points
426+
unify_proof: &IOPProof<E>,
427+
unify_evals: &[E], // Evaluation of polys on unified points
395428
packed_proof: &Pcs::Proof,
396429
final_proof: &Option<Pcs::Proof>,
397430
transcript: &mut impl Transcript<E>,
398431
) -> Result<(), Error>
399432
where
400433
Pcs::Commitment: 'a,
401434
{
435+
assert_eq!(poly_num_vars.len(), points.len());
436+
assert_eq!(poly_evals.len(), points.len());
437+
// UNIFY SUMCHECK
438+
// Sample random coefficients for each poly
439+
let unify_coeffs = transcript.sample_vec(poly_num_vars.len());
440+
let claim = poly_evals.iter().zip(&unify_coeffs).map(|(e, c)| *e * *c).sum();
441+
let sumcheck_subclaim = IOPVerifierState::verify(claim, unify_proof, &VPAuxInfo { max_degree: 2, max_num_variables: poly_num_vars[0], phantom: Default::default() }, transcript);
442+
let packed_point = sumcheck_subclaim.point.iter().map(|c| c.elements).collect::<Vec<_>>();
443+
let claimed_eval = sumcheck_subclaim.expected_evaluation;
444+
// Compute the evaluation of every EQ
445+
let eq_evals = points.iter().map(|p| eq_eval(p, &packed_point[..p.len()]));
446+
let expected_eval = eq_evals.zip(unify_evals).zip(unify_coeffs).map(|((eq, poly), coeff)| eq * *poly * coeff).sum();
447+
assert_eq!(claimed_eval, expected_eval);
448+
449+
// VERIFY PACK POLYS
402450
// Replicate packing
403451
let (_, final_poly_num_vars, packed_comps, final_comp) = pack_poly_verifier(poly_num_vars);
404452
// TODO: Add unifying sumcheck if the points do not match
405453
// For now, assume that all polys are evaluated on the same points
406454
let packed_point = points[0].clone();
407455
let final_point = if let Some(final_poly_num_vars) = &final_poly_num_vars { packed_point[..*final_poly_num_vars].to_vec() } else { Vec::new() };
408456
// Use comps to compute evals for packed polys from regular evals
409-
let (packed_evals, final_eval) = compute_packed_eval(&packed_point, &final_point, evals, &packed_comps, &final_comp);
457+
let (packed_evals, final_eval) = compute_packed_eval(&packed_point, &final_point, unify_evals, &packed_comps, &final_comp);
410458

411459
Pcs::simple_batch_verify(vp, packed_comm, &packed_point, &packed_evals, packed_proof, transcript)?;
412460
match (&final_comm, &final_eval, &final_proof) {
@@ -964,7 +1012,7 @@ pub mod test_util {
9641012
assert!(max_num_vars > vars_gap * batch_size);
9651013
let (pp, vp) = setup_pcs::<E, Pcs>(max_num_vars);
9661014

967-
let (poly_num_vars, packed_comm, final_comm, evals, packed_proof, final_proof, challenge) = {
1015+
let (poly_num_vars, packed_comm, final_comm, poly_evals, unify_evals, unify_proof, packed_proof, final_proof, challenge) = {
9681016
let mut transcript = BasicTranscript::new(b"BaseFold");
9691017
let polys: Vec<DenseMultilinearExtension<E>> = (0..batch_size).map(|i|
9701018
gen_rand_polys(|_| max_num_vars - i * vars_gap, 1, gen_rand_poly)
@@ -975,12 +1023,14 @@ pub mod test_util {
9751023
let evals = polys.iter().zip(&points).map(|(poly, point)| poly.evaluate(point)).collect_vec();
9761024
transcript.append_field_element_exts(&evals);
9771025

978-
let (packed_proof, final_proof) = pcs_batch_open_diff_size::<E, Pcs>(&pp, &polys, &packed_comm, &final_comm, &points, &evals, &mut transcript).unwrap();
1026+
let (unify_proof, unify_evals, packed_proof, final_proof) = pcs_batch_open_diff_size::<E, Pcs>(&pp, &polys, &packed_comm, &final_comm, &points, &evals, &mut transcript).unwrap();
9791027
(
9801028
polys.iter().map(|p| p.num_vars()).collect::<Vec<_>>(),
9811029
Pcs::get_pure_commitment(&packed_comm),
9821030
if let Some(final_comm) = final_comm { Some(Pcs::get_pure_commitment(&final_comm)) } else { None },
9831031
evals,
1032+
unify_evals,
1033+
unify_proof,
9841034
packed_proof,
9851035
final_proof,
9861036
transcript.read_challenge(),
@@ -996,9 +1046,9 @@ pub mod test_util {
9961046

9971047
let point = get_point_from_challenge(max_num_vars, &mut transcript);
9981048
let points: Vec<Vec<E>> = poly_num_vars.iter().map(|n| point[..*n].to_vec()).collect();
999-
transcript.append_field_element_exts(&evals);
1049+
transcript.append_field_element_exts(&poly_evals);
10001050

1001-
pcs_batch_verify_diff_size::<E, Pcs>(&vp, &poly_num_vars, &packed_comm, &final_comm, &points, &evals, &packed_proof, &final_proof, &mut transcript).unwrap();
1051+
pcs_batch_verify_diff_size::<E, Pcs>(&vp, &poly_num_vars, &packed_comm, &final_comm, &points, &poly_evals, &unify_proof, &unify_evals, &packed_proof, &final_proof, &mut transcript).unwrap();
10021052

10031053
let v_challenge = transcript.read_challenge();
10041054
assert_eq!(challenge, v_challenge);

mpcs/src/whir.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ mod tests {
266266
gen_rand_poly,
267267
20,
268268
3,
269-
5,
269+
3,
270270
);
271271
}
272272
}

sumcheck/src/prover.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,9 @@ impl<'a, E: ExtensionField> IOPProverState<'a, E> {
439439
_ => unimplemented!("do not support degree {} > 5", products.len()),
440440
};
441441
exit_span!(span);
442+
if self.round == 1 {
443+
println!("SUM: {:?}", (sum[0] + sum[1]));
444+
}
442445
sum.iter_mut().for_each(|sum| *sum *= *coefficient);
443446

444447
let span = entered_span!("extrapolation");

0 commit comments

Comments
 (0)