diff --git a/Cargo.lock b/Cargo.lock index d319c82f2..4a202ce51 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1904,7 +1904,7 @@ dependencies = [ [[package]] name = "ff_ext" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "once_cell", "p3", @@ -2716,7 +2716,7 @@ dependencies = [ [[package]] name = "mpcs" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "bincode", "clap", @@ -2740,7 +2740,7 @@ dependencies = [ [[package]] name = "multilinear_extensions" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "either", "ff_ext", @@ -3061,7 +3061,7 @@ dependencies = [ [[package]] name = "p3" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "p3-air", "p3-baby-bear", @@ -3498,7 +3498,7 @@ dependencies = [ [[package]] name = "poseidon" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "ff_ext", "p3", @@ -4482,7 +4482,7 @@ dependencies = [ [[package]] name = "sp1-curves" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "cfg-if", "dashu", @@ -4604,7 +4604,7 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "sumcheck" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "either", "ff_ext", @@ -4622,7 +4622,7 @@ dependencies = [ [[package]] name = "sumcheck_macro" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "itertools 0.13.0", "p3", @@ -5017,7 +5017,7 @@ dependencies = [ [[package]] name = "transcript" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "ff_ext", "itertools 0.13.0", @@ -5289,7 +5289,7 @@ dependencies = [ [[package]] name = "whir" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "bincode", "clap", @@ -5576,7 +5576,7 @@ dependencies = [ [[package]] name = "witness" version = "0.1.0" -source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.13#89aa6add9f4d16cd2f10ec81f7c11d4507400c9b" +source = "git+https://github.com/scroll-tech/gkr-backend.git?tag=v1.0.0-alpha.14#08462f562a091792262ddb63eb5b774f9896be77" dependencies = [ "ff_ext", "multilinear_extensions", diff --git a/Cargo.toml b/Cargo.toml index a0b824a62..8a573fc94 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,16 +23,16 @@ repository = "https://github.com/scroll-tech/ceno" version = "0.1.0" [workspace.dependencies] -ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.13" } -mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.13" } -multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.13" } -p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.13" } -poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.13" } -sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.13" } -sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.13" } -transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.13" } -whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.13" } -witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.13" } +ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.14" } +mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.14" } +multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.14" } +p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.14" } +poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.14" } +sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.14" } +sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.14" } +transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.14" } +whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.14" } +witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.14" } alloy-primitives = "1.3" anyhow = { version = "1.0", default-features = false } @@ -97,17 +97,17 @@ opt-level = 3 [profile.release] lto = "thin" -# [patch."ssh://git@github.com/scroll-tech/ceno-gpu.git"] -# ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features=["bb31"] } - -# [patch."https://github.com/scroll-tech/gkr-backend"] -# ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } -# mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } -# multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } -# p3 = { path = "../gkr-backend/crates/p3", package = "p3" } -# poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } -# sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } -# sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } -# transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } -# whir = { path = "../gkr-backend/crates/whir", package = "whir" } -# witness = { path = "../gkr-backend/crates/witness", package = "witness" } +#[patch."ssh://git@github.com/scroll-tech/ceno-gpu.git"] +#ceno_gpu = { path = "../ceno-gpu/cuda_hal", package = "cuda_hal", default-features = false, features = ["bb31"] } +# +#[patch."https://github.com/scroll-tech/gkr-backend"] +#ff_ext = { path = "../gkr-backend/crates/ff_ext", package = "ff_ext" } +#mpcs = { path = "../gkr-backend/crates/mpcs", package = "mpcs" } +#multilinear_extensions = { path = "../gkr-backend/crates/multilinear_extensions", package = "multilinear_extensions" } +#p3 = { path = "../gkr-backend/crates/p3", package = "p3" } +#poseidon = { path = "../gkr-backend/crates/poseidon", package = "poseidon" } +#sp1-curves = { path = "../gkr-backend/crates/curves", package = "sp1-curves" } +#sumcheck = { path = "../gkr-backend/crates/sumcheck", package = "sumcheck" } +#transcript = { path = "../gkr-backend/crates/transcript", package = "transcript" } +#whir = { path = "../gkr-backend/crates/whir", package = "whir" } +#witness = { path = "../gkr-backend/crates/witness", package = "witness" } diff --git a/gkr_iop/src/gkr/layer.rs b/gkr_iop/src/gkr/layer.rs index 64eb747be..9f23a40ba 100644 --- a/gkr_iop/src/gkr/layer.rs +++ b/gkr_iop/src/gkr/layer.rs @@ -6,6 +6,7 @@ use multilinear_extensions::{ Expression, Instance, StructuralWitIn, ToExpr, mle::{Point, PointAndEval}, monomial::Term, + utils::Node, }; use p3::field::FieldAlgebra; use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator}; @@ -48,6 +49,8 @@ pub enum LayerType { Linear, } +pub type DagInfo = (Vec, Vec>, u32, usize, usize); + #[derive(Clone, Debug, Serialize, Deserialize)] #[serde(bound( serialize = "E::BaseField: Serialize", @@ -103,6 +106,10 @@ pub struct Layer { pub main_sumcheck_expression_monomial_terms: Option, Expression>>>, pub main_sumcheck_expression: Option>, + // flatten computation dag + // (dag, coeffs, final_out_index, max_dag_depth, max_degree) + pub main_sumcheck_expression_dag: Option>, + // rotation sumcheck expression, only optionally valid for zerocheck // store in 2 forms: expression & monomial pub rotation_sumcheck_expression_monomial_terms: @@ -175,6 +182,7 @@ impl Layer { expr_names, main_sumcheck_expression_monomial_terms: None, main_sumcheck_expression: None, + main_sumcheck_expression_dag: None, rotation_sumcheck_expression_monomial_terms: None, rotation_sumcheck_expression: None, }; diff --git a/gkr_iop/src/gkr/layer/gpu/mod.rs b/gkr_iop/src/gkr/layer/gpu/mod.rs index 5840111e7..ee0eb05c4 100644 --- a/gkr_iop/src/gkr/layer/gpu/mod.rs +++ b/gkr_iop/src/gkr/layer/gpu/mod.rs @@ -18,6 +18,7 @@ use multilinear_extensions::{ Expression, mle::{MultilinearExtension, Point}, monomial::Term, + utils::eval_by_expr_constant, }; use rayon::{ iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator}, @@ -232,8 +233,26 @@ impl> ZerocheckLayerProver layer.n_fixed, layer.n_instance, ); + + // process dag + // (dag, coeffs, final_out_index, max_dag_depth, max_degree) + let (dag, dag_coeffs, final_out_index, max_dag_depth, max_degree) = + layer.main_sumcheck_expression_dag.as_ref().unwrap(); + + let pub_io_eval_scalar = pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec(); + // format: pub_io ++ challenge ++ constant + let dag_coeffs = dag_coeffs + .iter() + .map(|c| eval_by_expr_constant(&pub_io_eval_scalar, &main_sumcheck_challenges, c)) + .map(|either_v| match either_v { + Either::Left(base_field_val) => E::from(base_field_val), + Either::Right(ext_field_val) => ext_field_val, + }) + .collect_vec(); + + // process monomial terms // Calculate max_num_var and max_degree from the extracted relationships - let (term_coefficients, mle_indices_per_term, mle_size_info) = + let (monomial_coefficients, mle_indices_per_term, mle_size_info) = extract_mle_relationships_from_monomial_terms( &layer .main_sumcheck_expression_monomial_terms @@ -243,18 +262,18 @@ impl> ZerocheckLayerProver &pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec(), &main_sumcheck_challenges, ); + let max_num_var = max_num_variables; - let max_degree = mle_indices_per_term - .iter() - .map(|indices| indices.len()) - .max() - .unwrap_or(0); + + // Convert types for GPU function Call + let monomial_coefficients: Vec = + unsafe { std::mem::transmute(monomial_coefficients) }; // Convert types for GPU function Call let basic_tr: &mut BasicTranscript = unsafe { &mut *(transcript as *mut _ as *mut BasicTranscript) }; - let term_coefficients_gl64: Vec = - unsafe { std::mem::transmute(term_coefficients) }; + let dag_coeffs: Vec = unsafe { std::mem::transmute(dag_coeffs) }; + let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu> = unsafe { std::mem::transmute(all_witins_gpu) }; let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec(); @@ -264,13 +283,18 @@ impl> ZerocheckLayerProver &cuda_hal, all_witins_gpu_type_gl64, &mle_size_info, - &term_coefficients_gl64, + &monomial_coefficients, &mle_indices_per_term, max_num_var, - max_degree, + *max_degree, + dag, + *max_dag_depth, + &dag_coeffs, + *final_out_index, basic_tr, ) .unwrap(); + let evals_gpu = evals_gpu.into_iter().flatten().collect_vec(); let row_challenges = challenges_gpu.iter().map(|c| c.elements).collect_vec(); @@ -389,6 +413,10 @@ pub(crate) fn prove_rotation_gpu ZerocheckLayer for Layer { self.n_fixed as WitnessId, self.n_instance, ); - tracing::debug!("main sumcheck degree: {}", zero_expr.degree()); + let zero_expr_degree = zero_expr.degree(); self.main_sumcheck_expression = Some(zero_expr); self.main_sumcheck_expression_monomial_terms = self .main_sumcheck_expression .as_ref() .map(|expr| expr.get_monomial_terms()); - tracing::debug!( - "main sumcheck monomial terms count: {}", + + { + if let Some(terms) = self.main_sumcheck_expression_monomial_terms.as_ref() { + let num_mul: usize = terms.iter().map(|term| term.product.len()).sum(); + let num_add = terms.iter().len() - 1; + + tracing::debug!( + "layer name {} monomial num_add: {num_add} num_mul: {num_mul}", + self.name, + ); + } + } + + self.main_sumcheck_expression_dag = { self.main_sumcheck_expression_monomial_terms .as_ref() - .map_or(0, |terms| terms.len()), - ); + .map(|terms| { + // selector are structural witin, which is used to be the largest id. + let (dag, coeffs, Some(final_out_index), max_dag_depth) = build_factored_dag_commutative(terms, false) else { panic!() }; + let max_degree = zero_expr_degree; + + let (num_add, num_mul) = dag_stats(&dag); + tracing::debug!( + "layer name {} dag got num_add {num_add} num_mul {num_mul} max_degree {max_degree} \ + max_dag_depth {max_dag_depth} num_scalar {} final_out_index {final_out_index}", + self.name, + coeffs.len(), + ); + (dag, coeffs, final_out_index, max_dag_depth as usize, zero_expr_degree) + }) + }; + exit_span!(span); }