Skip to content

Commit d09d75b

Browse files
committed
switch to new dag api
1 parent 1ba3f89 commit d09d75b

File tree

3 files changed

+65
-84
lines changed

3 files changed

+65
-84
lines changed

gkr_iop/src/gkr/layer.rs

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,8 @@ pub struct Layer<E: ExtensionField> {
105105
pub main_sumcheck_expression: Option<Expression<E>>,
106106

107107
// flatten computation dag
108-
pub main_sumcheck_expression_dag: Option<(
109-
Vec<Node>,
110-
Vec<Instance>,
111-
Vec<Expression<E>>,
112-
Vec<Either<E::BaseField, E>>,
113-
u32,
114-
(usize, usize),
115-
)>,
108+
// (dag, coeffs, final_out_index, max_dag_depth, max_degree)
109+
pub main_sumcheck_expression_dag: Option<(Vec<Node>, Vec<Expression<E>>, u32, usize, usize)>,
116110

117111
// rotation sumcheck expression, only optionally valid for zerocheck
118112
// store in 2 forms: expression & monomial

gkr_iop/src/gkr/layer/gpu/mod.rs

Lines changed: 11 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ use ff_ext::ExtensionField;
1515
use itertools::{Itertools, chain};
1616
use mpcs::PolynomialCommitmentScheme;
1717
use multilinear_extensions::{
18-
Expression, Instance,
18+
Expression,
1919
mle::{MultilinearExtension, Point},
2020
monomial::Term,
2121
utils::eval_by_expr_constant,
@@ -235,35 +235,19 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
235235
);
236236

237237
let (proof_gpu, evals_gpu, challenges_gpu) = if layer.exprs.len() > 200 {
238-
let (
239-
dag,
240-
instance_scalar_expr,
241-
challenges_expr,
242-
constant_expr,
243-
stack_top,
244-
(max_degree, max_dag_depth),
245-
) = layer.main_sumcheck_expression_dag.as_ref().unwrap();
238+
// (dag, coeffs, final_out_index, max_dag_depth, max_degree)
239+
let (dag, coeffs, stack_top, max_dag_depth, max_degree) =
240+
layer.main_sumcheck_expression_dag.as_ref().unwrap();
246241

242+
let pub_io_eval_scalar = pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec();
247243
// format: pub_io ++ challenge ++ constant
248-
let term_coefficients = instance_scalar_expr
244+
let term_coefficients = coeffs
249245
.iter()
250-
.map(|Instance(id)| pub_io_evals[*id])
251-
.chain(
252-
challenges_expr
253-
.iter()
254-
.map(|c| {
255-
eval_by_expr_constant(
256-
&pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec(),
257-
&main_sumcheck_challenges,
258-
c,
259-
)
260-
})
261-
.chain(constant_expr.iter().copied())
262-
.map(|either_v| match either_v {
263-
Either::Left(base_field_val) => E::from(base_field_val),
264-
Either::Right(ext_field_val) => ext_field_val,
265-
}),
266-
)
246+
.map(|c| eval_by_expr_constant(&pub_io_eval_scalar, &main_sumcheck_challenges, c))
247+
.map(|either_v| match either_v {
248+
Either::Left(base_field_val) => E::from(base_field_val),
249+
Either::Right(ext_field_val) => ext_field_val,
250+
})
267251
.collect_vec();
268252

269253
let max_num_var = max_num_variables;

gkr_iop/src/gkr/layer/zerocheck_layer.rs

Lines changed: 52 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use multilinear_extensions::{
66
mle::{IntoMLE, Point},
77
monomialize_expr_to_wit_terms,
88
utils::{
9-
eval_by_expr, eval_by_expr_with_instance, expr_compression_to_dag, expr_convert_to_witins,
9+
build_factored_dag_commutative, eval_by_expr, eval_by_expr_with_instance,
10+
expr_convert_to_witins,
1011
},
1112
virtual_poly::VPAuxInfo,
1213
};
@@ -167,61 +168,63 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
167168
self.n_fixed as WitnessId,
168169
self.n_instance,
169170
);
170-
tracing::debug!("main sumcheck degree: {}", zero_expr.degree());
171+
let zero_expr_degree = zero_expr.degree();
171172
self.main_sumcheck_expression = Some(zero_expr);
172-
self.main_sumcheck_expression_dag = Some({
173-
let (
174-
dag,
175-
instance_scalar_expr,
176-
challenges_expr,
177-
constant_expr,
178-
stack_top,
179-
(max_degree, max_dag_depth),
180-
) = expr_compression_to_dag(self.main_sumcheck_expression.as_ref().unwrap());
181-
182-
let mut num_add = 0;
183-
let mut num_mul = 0;
184-
185-
for node in &dag {
186-
match node.op {
187-
0 => (), // skip wit index
188-
1 => (), // skip scalar index
189-
2 => {
190-
num_add += 1;
191-
}
192-
3 => {
193-
num_mul += 1;
194-
}
195-
op => panic!("unknown op {op}"),
196-
}
197-
}
198-
199-
tracing::debug!(
200-
"layer name {} dag got num_add {num_add} num_mul {num_mul} max_degree {max_degree} \
201-
max_dag_depth {max_dag_depth} num_scalar {}",
202-
self.name,
203-
instance_scalar_expr.len() + challenges_expr.len() + constant_expr.len(),
204-
);
205-
206-
(
207-
dag,
208-
instance_scalar_expr,
209-
challenges_expr,
210-
constant_expr,
211-
stack_top,
212-
(max_degree, max_dag_depth),
213-
)
214-
});
215173
self.main_sumcheck_expression_monomial_terms = self
216174
.main_sumcheck_expression
217175
.as_ref()
218176
.map(|expr| expr.get_monomial_terms());
219-
tracing::debug!(
220-
"main sumcheck monomial terms count: {}",
177+
178+
{
221179
self.main_sumcheck_expression_monomial_terms
222180
.as_ref()
223-
.map_or(0, |terms| terms.len()),
224-
);
181+
.map(|terms| {
182+
let num_mul: usize = terms.iter().map(|term| term.product.len()).sum();
183+
let num_add = terms.iter().len() - 1;
184+
185+
tracing::debug!(
186+
"layer name {} monomial num_add: {num_add} num_mul: {num_mul}",
187+
self.name,
188+
);
189+
});
190+
}
191+
192+
self.main_sumcheck_expression_dag = {
193+
self.main_sumcheck_expression_monomial_terms
194+
.as_ref()
195+
.map(|terms| {
196+
// selector are structural witin, which is used to be the largest id.
197+
let (dag, coeffs, Some(final_out_index), max_dag_depth) = build_factored_dag_commutative(terms, false) else { panic!() };
198+
let stack_top = final_out_index + 1;
199+
let max_degree = zero_expr_degree;
200+
201+
let mut num_add = 0;
202+
let mut num_mul = 0;
203+
204+
for node in &dag {
205+
match node.op {
206+
0 => (), // skip wit index
207+
1 => (), // skip scalar index
208+
2 => {
209+
num_add += 1;
210+
}
211+
3 => {
212+
num_mul += 1;
213+
}
214+
op => panic!("unknown op {op}"),
215+
}
216+
}
217+
218+
tracing::debug!(
219+
"layer name {} dag got num_add {num_add} num_mul {num_mul} max_degree {max_degree} \
220+
max_dag_depth {max_dag_depth} num_scalar {} final_out_index {final_out_index}",
221+
self.name,
222+
coeffs.len(),
223+
);
224+
(dag, coeffs, stack_top, max_dag_depth as usize, zero_expr_degree)
225+
})
226+
};
227+
225228
exit_span!(span);
226229
}
227230

0 commit comments

Comments
 (0)