Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 11 additions & 11 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

20 changes: 10 additions & 10 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,16 @@ repository = "https://github.com/scroll-tech/ceno"
version = "0.1.0"

[workspace.dependencies]
ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", tag = "v1.0.0-alpha.13" }
mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", tag = "v1.0.0-alpha.13" }
multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", tag = "v1.0.0-alpha.13" }
p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", tag = "v1.0.0-alpha.13" }
poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", tag = "v1.0.0-alpha.13" }
sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", tag = "v1.0.0-alpha.13" }
sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", tag = "v1.0.0-alpha.13" }
transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", tag = "v1.0.0-alpha.13" }
whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", tag = "v1.0.0-alpha.13" }
witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", tag = "v1.0.0-alpha.13" }
ff_ext = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "ff_ext", branch = "feat/arithmetics" }
mpcs = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "mpcs", branch = "feat/arithmetics" }
multilinear_extensions = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "multilinear_extensions", branch = "feat/arithmetics" }
p3 = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "p3", branch = "feat/arithmetics" }
poseidon = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "poseidon", branch = "feat/arithmetics" }
sp1-curves = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sp1-curves", branch = "feat/arithmetics" }
sumcheck = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "sumcheck", branch = "feat/arithmetics" }
transcript = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "transcript", branch = "feat/arithmetics" }
whir = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "whir", branch = "feat/arithmetics" }
witness = { git = "https://github.com/scroll-tech/gkr-backend.git", package = "witness", branch = "feat/arithmetics" }

alloy-primitives = "1.3"
anyhow = { version = "1.0", default-features = false }
Expand Down
6 changes: 6 additions & 0 deletions gkr_iop/src/gkr/layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use multilinear_extensions::{
Expression, Instance, StructuralWitIn, ToExpr,
mle::{Point, PointAndEval},
monomial::Term,
utils::Node,
};
use p3::field::FieldAlgebra;
use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator};
Expand Down Expand Up @@ -103,6 +104,10 @@ pub struct Layer<E: ExtensionField> {
pub main_sumcheck_expression_monomial_terms: Option<Vec<Term<Expression<E>, Expression<E>>>>,
pub main_sumcheck_expression: Option<Expression<E>>,

// flatten computation dag
// (dag, coeffs, final_out_index, max_dag_depth, max_degree)
pub main_sumcheck_expression_dag: Option<(Vec<Node>, Vec<Expression<E>>, u32, usize, usize)>,

// rotation sumcheck expression, only optionally valid for zerocheck
// store in 2 forms: expression & monomial
pub rotation_sumcheck_expression_monomial_terms:
Expand Down Expand Up @@ -175,6 +180,7 @@ impl<E: ExtensionField> Layer<E> {
expr_names,
main_sumcheck_expression_monomial_terms: None,
main_sumcheck_expression: None,
main_sumcheck_expression_dag: None,
rotation_sumcheck_expression_monomial_terms: None,
rotation_sumcheck_expression: None,
};
Expand Down
48 changes: 38 additions & 10 deletions gkr_iop/src/gkr/layer/gpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use multilinear_extensions::{
Expression,
mle::{MultilinearExtension, Point},
monomial::Term,
utils::eval_by_expr_constant,
};
use rayon::{
iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator},
Expand Down Expand Up @@ -232,8 +233,26 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
layer.n_fixed,
layer.n_instance,
);

// process dag
// (dag, coeffs, final_out_index, max_dag_depth, max_degree)
let (dag, dag_coeffs, stack_top, max_dag_depth, max_degree) =
layer.main_sumcheck_expression_dag.as_ref().unwrap();

let pub_io_eval_scalar = pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec();
// format: pub_io ++ challenge ++ constant
let dag_coeffs = dag_coeffs
.iter()
.map(|c| eval_by_expr_constant(&pub_io_eval_scalar, &main_sumcheck_challenges, c))
.map(|either_v| match either_v {
Either::Left(base_field_val) => E::from(base_field_val),
Either::Right(ext_field_val) => ext_field_val,
})
.collect_vec();

// process monomial terms
// Calculate max_num_var and max_degree from the extracted relationships
let (term_coefficients, mle_indices_per_term, mle_size_info) =
let (monomial_coefficients, mle_indices_per_term, mle_size_info) =
extract_mle_relationships_from_monomial_terms(
&layer
.main_sumcheck_expression_monomial_terms
Expand All @@ -243,18 +262,18 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
&pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec(),
&main_sumcheck_challenges,
);

let max_num_var = max_num_variables;
let max_degree = mle_indices_per_term
.iter()
.map(|indices| indices.len())
.max()
.unwrap_or(0);

// Convert types for GPU function Call
let monomial_coefficients: Vec<BB31Ext> =
unsafe { std::mem::transmute(monomial_coefficients) };

// Convert types for GPU function Call
let basic_tr: &mut BasicTranscript<BB31Ext> =
unsafe { &mut *(transcript as *mut _ as *mut BasicTranscript<BB31Ext>) };
let term_coefficients_gl64: Vec<BB31Ext> =
unsafe { std::mem::transmute(term_coefficients) };
let dag_coeffs: Vec<BB31Ext> = unsafe { std::mem::transmute(dag_coeffs) };

let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu<BB31Ext>> =
unsafe { std::mem::transmute(all_witins_gpu) };
let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec();
Expand All @@ -264,13 +283,18 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
&cuda_hal,
all_witins_gpu_type_gl64,
&mle_size_info,
&term_coefficients_gl64,
&monomial_coefficients,
&mle_indices_per_term,
max_num_var,
max_degree,
*max_degree,
dag,
*max_dag_depth,
&dag_coeffs,
*stack_top,
basic_tr,
)
.unwrap();

let evals_gpu = evals_gpu.into_iter().flatten().collect_vec();
let row_challenges = challenges_gpu.iter().map(|c| c.elements).collect_vec();

Expand Down Expand Up @@ -389,6 +413,10 @@ pub(crate) fn prove_rotation_gpu<E: ExtensionField, PCS: PolynomialCommitmentSch
&mle_indices_per_term,
max_num_var,
max_degree,
&[],
0,
&[],
0,
basic_tr,
)
.unwrap();
Expand Down
60 changes: 54 additions & 6 deletions gkr_iop/src/gkr/layer/zerocheck_layer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use multilinear_extensions::{
macros::{entered_span, exit_span},
mle::{IntoMLE, Point},
monomialize_expr_to_wit_terms,
utils::{eval_by_expr, eval_by_expr_with_instance, expr_convert_to_witins},
utils::{
build_factored_dag_commutative, eval_by_expr, eval_by_expr_with_instance,
expr_convert_to_witins,
},
virtual_poly::VPAuxInfo,
};
use p3::field::{FieldAlgebra, dot_product};
Expand Down Expand Up @@ -165,18 +168,63 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
self.n_fixed as WitnessId,
self.n_instance,
);
tracing::debug!("main sumcheck degree: {}", zero_expr.degree());
let zero_expr_degree = zero_expr.degree();
self.main_sumcheck_expression = Some(zero_expr);
self.main_sumcheck_expression_monomial_terms = self
.main_sumcheck_expression
.as_ref()
.map(|expr| expr.get_monomial_terms());
tracing::debug!(
"main sumcheck monomial terms count: {}",

{
self.main_sumcheck_expression_monomial_terms
.as_ref()
.map_or(0, |terms| terms.len()),
);
.map(|terms| {
let num_mul: usize = terms.iter().map(|term| term.product.len()).sum();
let num_add = terms.iter().len() - 1;

tracing::debug!(
"layer name {} monomial num_add: {num_add} num_mul: {num_mul}",
self.name,
);
});
}

self.main_sumcheck_expression_dag = {
self.main_sumcheck_expression_monomial_terms
.as_ref()
.map(|terms| {
// selector are structural witin, which is used to be the largest id.
let (dag, coeffs, Some(final_out_index), max_dag_depth) = build_factored_dag_commutative(terms, false) else { panic!() };
let stack_top = final_out_index + 1;
let max_degree = zero_expr_degree;

let mut num_add = 0;
let mut num_mul = 0;

for node in &dag {
match node.op {
0 => (), // skip wit index
1 => (), // skip scalar index
2 => {
num_add += 1;
}
3 => {
num_mul += 1;
}
op => panic!("unknown op {op}"),
}
}

tracing::debug!(
"layer name {} dag got num_add {num_add} num_mul {num_mul} max_degree {max_degree} \
max_dag_depth {max_dag_depth} num_scalar {} final_out_index {final_out_index}",
self.name,
coeffs.len(),
);
(dag, coeffs, stack_top, max_dag_depth as usize, zero_expr_degree)
})
};

exit_span!(span);
}

Expand Down
Loading