11#![ deny( clippy:: cargo) ]
22use ff_ext:: ExtensionField ;
3- use itertools:: Itertools ;
4- use multilinear_extensions:: mle:: { DenseMultilinearExtension , FieldType , MultilinearExtension } ;
3+ use itertools:: { Either , Itertools } ;
4+ use multilinear_extensions:: { mle:: { DenseMultilinearExtension , FieldType , MultilinearExtension } , virtual_poly :: { build_eq_x_r , eq_eval , VPAuxInfo } } ;
55use serde:: { Serialize , de:: DeserializeOwned } ;
66use std:: fmt:: Debug ;
77use transcript:: { BasicTranscript , Transcript } ;
88use util:: hash:: Digest ;
99use p3_field:: PrimeCharacteristicRing ;
10+ use multilinear_extensions:: virtual_poly:: VirtualPolynomial ;
11+ use sumcheck:: structs:: { IOPProof , IOPProverState , IOPVerifierState } ;
1012
1113pub mod sum_check;
1214pub mod util;
@@ -334,21 +336,50 @@ pub fn pcs_batch_open_diff_size<E: ExtensionField, Pcs: PolynomialCommitmentSche
334336 packed_comm : & Pcs :: CommitmentWithWitness ,
335337 final_comm : & Option < Pcs :: CommitmentWithWitness > ,
336338 points : & [ Vec < E > ] ,
337- evals : & [ E ] ,
339+ _evals : & [ E ] ,
338340 transcript : & mut impl Transcript < E > ,
339- ) -> Result < ( Pcs :: Proof , Option < Pcs :: Proof > ) , Error > {
341+ ) -> Result < ( IOPProof < E > , Vec < E > , Pcs :: Proof , Option < Pcs :: Proof > ) , Error > {
342+ assert_eq ! ( polys. len( ) , points. len( ) ) ;
340343 // TODO: Sort the polys by decreasing size
344+ let arc_polys: Vec < ArcMultilinearExtension < E > > = polys. into_iter ( ) . map ( |p| ArcMultilinearExtension :: from ( p. clone ( ) ) ) . collect ( ) ;
345+ // UNIFY SUMCHECK
346+ // Sample random coefficients for each poly
347+ let unify_coeffs = transcript. sample_vec ( polys. len ( ) ) ;
348+ // First convert each point into EQ
349+ let eq_points = points. iter ( ) . map ( |p| build_eq_x_r ( p) ) . collect :: < Vec < _ > > ( ) ;
350+
351+ let mut sumcheck_poly = VirtualPolynomial :: < E > :: new ( polys[ 0 ] . num_vars ( ) ) ;
352+ for ( ( eq, poly) , coeff) in eq_points. into_iter ( ) . zip ( arc_polys) . zip ( unify_coeffs) {
353+ let claim = match ( & poly. evaluations ( ) , & eq. evaluations ) {
354+ ( FieldType :: Base ( p) , FieldType :: Ext ( e) ) => {
355+ p. iter ( ) . zip ( e) . map ( |( p, e) | E :: from_bases ( & [ * p, E :: BaseField :: ZERO ] ) * * e) . fold ( E :: ZERO , |s, i| s + i)
356+ }
357+ _ => unreachable ! ( )
358+ } ;
359+ println ! ( "C: {:?}" , claim) ;
360+ sumcheck_poly. add_mle_list ( vec ! [ eq, poly] , coeff) ;
361+ }
362+ let ( unify_proof, unify_prover_state) = IOPProverState :: prove_batch_polys ( 1 , vec ! [ sumcheck_poly] , transcript) ;
363+ let packed_point = unify_proof. point . clone ( ) ;
364+ // sumcheck_poly is consisted of [eq, poly, eq, poly, ...], we only need the evaluations to `poly` here
365+ let sumcheck_evals = unify_prover_state. get_mle_final_evaluations ( ) ;
366+ let ( _, evals) : ( Vec < _ > , Vec < _ > ) = sumcheck_evals. into_iter ( ) . enumerate ( ) . partition_map ( |( i, e) | {
367+ if i % 2 == 0 {
368+ Either :: Left ( e)
369+ } else {
370+ Either :: Right ( e)
371+ }
372+ } ) ;
373+
374+ // GEN & EVAL PACK POLYS
341375 // TODO: The prover should be able to avoid packing the polys again
342376 let ( packed_polys, final_poly, packed_comps, final_comp) = pack_poly_prover ( polys) ;
343- // TODO: Add unifying sumcheck if the points do not match
344- // For now, assume that all polys are evaluated on the same points
345- let packed_point = points[ 0 ] . clone ( ) ;
377+ let packed_polys: Vec < ArcMultilinearExtension < E > > = packed_polys. into_iter ( ) . map ( |p| ArcMultilinearExtension :: from ( p) ) . collect ( ) ;
346378 // Note: the points are stored in reverse
347379 let final_point = if let Some ( final_poly) = & final_poly { packed_point[ ..final_poly. num_vars ] . to_vec ( ) } else { Vec :: new ( ) } ;
348380 // Use comps to compute evals for packed polys from regular evals
349- let ( packed_evals, final_eval) = compute_packed_eval ( & packed_point, & final_point, evals, & packed_comps, & final_comp) ;
381+ let ( packed_evals, final_eval) = compute_packed_eval ( & packed_point, & final_point, & evals, & packed_comps, & final_comp) ;
350382
351- let packed_polys: Vec < ArcMultilinearExtension < E > > = packed_polys. into_iter ( ) . map ( |p| ArcMultilinearExtension :: from ( p) ) . collect ( ) ;
352383 let pack_proof = Pcs :: simple_batch_open ( pp, & packed_polys, packed_comm, & packed_point, & packed_evals, transcript) ?;
353384 let final_proof = match ( & final_poly, & final_comm, & final_eval) {
354385 ( Some ( final_poly) , Some ( final_comm) , Some ( final_eval) ) => {
@@ -357,7 +388,7 @@ pub fn pcs_batch_open_diff_size<E: ExtensionField, Pcs: PolynomialCommitmentSche
357388 ( None , None , None ) => None ,
358389 _ => unreachable ! ( ) ,
359390 } ;
360- Ok ( ( pack_proof, final_proof) )
391+ Ok ( ( unify_proof , evals , pack_proof, final_proof) )
361392}
362393
363394pub fn pcs_verify < E : ExtensionField , Pcs : PolynomialCommitmentScheme < E > > (
@@ -391,22 +422,39 @@ pub fn pcs_batch_verify_diff_size<'a, E: ExtensionField, Pcs: PolynomialCommitme
391422 packed_comm : & Pcs :: Commitment ,
392423 final_comm : & Option < Pcs :: Commitment > ,
393424 points : & [ Vec < E > ] ,
394- evals : & [ E ] ,
425+ poly_evals : & [ E ] , // Evaluation of polys on original points
426+ unify_proof : & IOPProof < E > ,
427+ unify_evals : & [ E ] , // Evaluation of polys on unified points
395428 packed_proof : & Pcs :: Proof ,
396429 final_proof : & Option < Pcs :: Proof > ,
397430 transcript : & mut impl Transcript < E > ,
398431) -> Result < ( ) , Error >
399432where
400433 Pcs :: Commitment : ' a ,
401434{
435+ assert_eq ! ( poly_num_vars. len( ) , points. len( ) ) ;
436+ assert_eq ! ( poly_evals. len( ) , points. len( ) ) ;
437+ // UNIFY SUMCHECK
438+ // Sample random coefficients for each poly
439+ let unify_coeffs = transcript. sample_vec ( poly_num_vars. len ( ) ) ;
440+ let claim = poly_evals. iter ( ) . zip ( & unify_coeffs) . map ( |( e, c) | * e * * c) . sum ( ) ;
441+ let sumcheck_subclaim = IOPVerifierState :: verify ( claim, unify_proof, & VPAuxInfo { max_degree : 2 , max_num_variables : poly_num_vars[ 0 ] , phantom : Default :: default ( ) } , transcript) ;
442+ let packed_point = sumcheck_subclaim. point . iter ( ) . map ( |c| c. elements ) . collect :: < Vec < _ > > ( ) ;
443+ let claimed_eval = sumcheck_subclaim. expected_evaluation ;
444+ // Compute the evaluation of every EQ
445+ let eq_evals = points. iter ( ) . map ( |p| eq_eval ( p, & packed_point[ ..p. len ( ) ] ) ) ;
446+ let expected_eval = eq_evals. zip ( unify_evals) . zip ( unify_coeffs) . map ( |( ( eq, poly) , coeff) | eq * * poly * coeff) . sum ( ) ;
447+ assert_eq ! ( claimed_eval, expected_eval) ;
448+
449+ // VERIFY PACK POLYS
402450 // Replicate packing
403451 let ( _, final_poly_num_vars, packed_comps, final_comp) = pack_poly_verifier ( poly_num_vars) ;
404452 // TODO: Add unifying sumcheck if the points do not match
405453 // For now, assume that all polys are evaluated on the same points
406454 let packed_point = points[ 0 ] . clone ( ) ;
407455 let final_point = if let Some ( final_poly_num_vars) = & final_poly_num_vars { packed_point[ ..* final_poly_num_vars] . to_vec ( ) } else { Vec :: new ( ) } ;
408456 // Use comps to compute evals for packed polys from regular evals
409- let ( packed_evals, final_eval) = compute_packed_eval ( & packed_point, & final_point, evals , & packed_comps, & final_comp) ;
457+ let ( packed_evals, final_eval) = compute_packed_eval ( & packed_point, & final_point, unify_evals , & packed_comps, & final_comp) ;
410458
411459 Pcs :: simple_batch_verify ( vp, packed_comm, & packed_point, & packed_evals, packed_proof, transcript) ?;
412460 match ( & final_comm, & final_eval, & final_proof) {
@@ -964,7 +1012,7 @@ pub mod test_util {
9641012 assert ! ( max_num_vars > vars_gap * batch_size) ;
9651013 let ( pp, vp) = setup_pcs :: < E , Pcs > ( max_num_vars) ;
9661014
967- let ( poly_num_vars, packed_comm, final_comm, evals , packed_proof, final_proof, challenge) = {
1015+ let ( poly_num_vars, packed_comm, final_comm, poly_evals , unify_evals , unify_proof , packed_proof, final_proof, challenge) = {
9681016 let mut transcript = BasicTranscript :: new ( b"BaseFold" ) ;
9691017 let polys: Vec < DenseMultilinearExtension < E > > = ( 0 ..batch_size) . map ( |i|
9701018 gen_rand_polys ( |_| max_num_vars - i * vars_gap, 1 , gen_rand_poly)
@@ -975,12 +1023,14 @@ pub mod test_util {
9751023 let evals = polys. iter ( ) . zip ( & points) . map ( |( poly, point) | poly. evaluate ( point) ) . collect_vec ( ) ;
9761024 transcript. append_field_element_exts ( & evals) ;
9771025
978- let ( packed_proof, final_proof) = pcs_batch_open_diff_size :: < E , Pcs > ( & pp, & polys, & packed_comm, & final_comm, & points, & evals, & mut transcript) . unwrap ( ) ;
1026+ let ( unify_proof , unify_evals , packed_proof, final_proof) = pcs_batch_open_diff_size :: < E , Pcs > ( & pp, & polys, & packed_comm, & final_comm, & points, & evals, & mut transcript) . unwrap ( ) ;
9791027 (
9801028 polys. iter ( ) . map ( |p| p. num_vars ( ) ) . collect :: < Vec < _ > > ( ) ,
9811029 Pcs :: get_pure_commitment ( & packed_comm) ,
9821030 if let Some ( final_comm) = final_comm { Some ( Pcs :: get_pure_commitment ( & final_comm) ) } else { None } ,
9831031 evals,
1032+ unify_evals,
1033+ unify_proof,
9841034 packed_proof,
9851035 final_proof,
9861036 transcript. read_challenge ( ) ,
@@ -996,9 +1046,9 @@ pub mod test_util {
9961046
9971047 let point = get_point_from_challenge ( max_num_vars, & mut transcript) ;
9981048 let points: Vec < Vec < E > > = poly_num_vars. iter ( ) . map ( |n| point[ ..* n] . to_vec ( ) ) . collect ( ) ;
999- transcript. append_field_element_exts ( & evals ) ;
1049+ transcript. append_field_element_exts ( & poly_evals ) ;
10001050
1001- pcs_batch_verify_diff_size :: < E , Pcs > ( & vp, & poly_num_vars, & packed_comm, & final_comm, & points, & evals , & packed_proof, & final_proof, & mut transcript) . unwrap ( ) ;
1051+ pcs_batch_verify_diff_size :: < E , Pcs > ( & vp, & poly_num_vars, & packed_comm, & final_comm, & points, & poly_evals , & unify_proof , & unify_evals , & packed_proof, & final_proof, & mut transcript) . unwrap ( ) ;
10021052
10031053 let v_challenge = transcript. read_challenge ( ) ;
10041054 assert_eq ! ( challenge, v_challenge) ;
0 commit comments