Skip to content

Commit 81f1452

Browse files
committed
switch to new dag api
1 parent 1ba3f89 commit 81f1452

File tree

1 file changed

+51
-49
lines changed

1 file changed

+51
-49
lines changed

gkr_iop/src/gkr/layer/zerocheck_layer.rs

Lines changed: 51 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ use multilinear_extensions::{
66
mle::{IntoMLE, Point},
77
monomialize_expr_to_wit_terms,
88
utils::{
9-
eval_by_expr, eval_by_expr_with_instance, expr_compression_to_dag, expr_convert_to_witins,
9+
build_factored_dag_commutative, eval_by_expr, eval_by_expr_with_instance,
10+
expr_convert_to_witins,
1011
},
1112
virtual_poly::VPAuxInfo,
1213
};
@@ -167,61 +168,62 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
167168
self.n_fixed as WitnessId,
168169
self.n_instance,
169170
);
170-
tracing::debug!("main sumcheck degree: {}", zero_expr.degree());
171+
let zero_expr_degree = zero_expr.degree();
171172
self.main_sumcheck_expression = Some(zero_expr);
172-
self.main_sumcheck_expression_dag = Some({
173-
let (
174-
dag,
175-
instance_scalar_expr,
176-
challenges_expr,
177-
constant_expr,
178-
stack_top,
179-
(max_degree, max_dag_depth),
180-
) = expr_compression_to_dag(self.main_sumcheck_expression.as_ref().unwrap());
181-
182-
let mut num_add = 0;
183-
let mut num_mul = 0;
184-
185-
for node in &dag {
186-
match node.op {
187-
0 => (), // skip wit index
188-
1 => (), // skip scalar index
189-
2 => {
190-
num_add += 1;
191-
}
192-
3 => {
193-
num_mul += 1;
194-
}
195-
op => panic!("unknown op {op}"),
196-
}
197-
}
198-
199-
tracing::debug!(
200-
"layer name {} dag got num_add {num_add} num_mul {num_mul} max_degree {max_degree} \
201-
max_dag_depth {max_dag_depth} num_scalar {}",
202-
self.name,
203-
instance_scalar_expr.len() + challenges_expr.len() + constant_expr.len(),
204-
);
205-
206-
(
207-
dag,
208-
instance_scalar_expr,
209-
challenges_expr,
210-
constant_expr,
211-
stack_top,
212-
(max_degree, max_dag_depth),
213-
)
214-
});
215173
self.main_sumcheck_expression_monomial_terms = self
216174
.main_sumcheck_expression
217175
.as_ref()
218176
.map(|expr| expr.get_monomial_terms());
219-
tracing::debug!(
220-
"main sumcheck monomial terms count: {}",
177+
178+
{
221179
self.main_sumcheck_expression_monomial_terms
222180
.as_ref()
223-
.map_or(0, |terms| terms.len()),
224-
);
181+
.map(|terms| {
182+
let num_mul: usize = terms.iter().map(|term| term.product.len()).sum();
183+
let num_add = terms.iter().len() - 1;
184+
185+
tracing::debug!(
186+
"layer name {} monomial num_add: {num_add} num_mul: {num_mul}",
187+
self.name,
188+
);
189+
});
190+
}
191+
192+
self.main_sumcheck_expression_dag = {
193+
self.main_sumcheck_expression_monomial_terms
194+
.as_ref()
195+
.map(|terms| {
196+
let (dag, coeffs, Some(final_out_index), max_dag_depth) = build_factored_dag_commutative(terms) else { panic!() };
197+
let stack_top = final_out_index + 1;
198+
let max_degree = zero_expr_degree;
199+
200+
let mut num_add = 0;
201+
let mut num_mul = 0;
202+
203+
for node in &dag {
204+
match node.op {
205+
0 => (), // skip wit index
206+
1 => (), // skip scalar index
207+
2 => {
208+
num_add += 1;
209+
}
210+
3 => {
211+
num_mul += 1;
212+
}
213+
op => panic!("unknown op {op}"),
214+
}
215+
}
216+
217+
tracing::debug!(
218+
"layer name {} dag got num_add {num_add} num_mul {num_mul} max_degree {max_degree} \
219+
max_dag_depth {max_dag_depth} num_scalar {} final_out_index {final_out_index}",
220+
self.name,
221+
coeffs.len(),
222+
);
223+
(dag, coeffs, stack_top, max_dag_depth as usize, zero_expr_degree)
224+
})
225+
};
226+
225227
exit_span!(span);
226228
}
227229

0 commit comments

Comments
 (0)