@@ -6,7 +6,8 @@ use multilinear_extensions::{
66 mle:: { IntoMLE , Point } ,
77 monomialize_expr_to_wit_terms,
88 utils:: {
9- eval_by_expr, eval_by_expr_with_instance, expr_compression_to_dag, expr_convert_to_witins,
9+ build_factored_dag_commutative, eval_by_expr, eval_by_expr_with_instance,
10+ expr_convert_to_witins,
1011 } ,
1112 virtual_poly:: VPAuxInfo ,
1213} ;
@@ -167,61 +168,62 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
167168 self . n_fixed as WitnessId ,
168169 self . n_instance ,
169170 ) ;
170- tracing :: debug! ( "main sumcheck degree: {}" , zero_expr. degree( ) ) ;
171+ let zero_expr_degree = zero_expr. degree ( ) ;
171172 self . main_sumcheck_expression = Some ( zero_expr) ;
172- self . main_sumcheck_expression_dag = Some ( {
173- let (
174- dag,
175- instance_scalar_expr,
176- challenges_expr,
177- constant_expr,
178- stack_top,
179- ( max_degree, max_dag_depth) ,
180- ) = expr_compression_to_dag ( self . main_sumcheck_expression . as_ref ( ) . unwrap ( ) ) ;
181-
182- let mut num_add = 0 ;
183- let mut num_mul = 0 ;
184-
185- for node in & dag {
186- match node. op {
187- 0 => ( ) , // skip wit index
188- 1 => ( ) , // skip scalar index
189- 2 => {
190- num_add += 1 ;
191- }
192- 3 => {
193- num_mul += 1 ;
194- }
195- op => panic ! ( "unknown op {op}" ) ,
196- }
197- }
198-
199- tracing:: debug!(
200- "layer name {} dag got num_add {num_add} num_mul {num_mul} max_degree {max_degree} \
201- max_dag_depth {max_dag_depth} num_scalar {}",
202- self . name,
203- instance_scalar_expr. len( ) + challenges_expr. len( ) + constant_expr. len( ) ,
204- ) ;
205-
206- (
207- dag,
208- instance_scalar_expr,
209- challenges_expr,
210- constant_expr,
211- stack_top,
212- ( max_degree, max_dag_depth) ,
213- )
214- } ) ;
215173 self . main_sumcheck_expression_monomial_terms = self
216174 . main_sumcheck_expression
217175 . as_ref ( )
218176 . map ( |expr| expr. get_monomial_terms ( ) ) ;
219- tracing :: debug! (
220- "main sumcheck monomial terms count: {}" ,
177+
178+ {
221179 self . main_sumcheck_expression_monomial_terms
222180 . as_ref ( )
223- . map_or( 0 , |terms| terms. len( ) ) ,
224- ) ;
181+ . map ( |terms| {
182+ let num_mul: usize = terms. iter ( ) . map ( |term| term. product . len ( ) ) . sum ( ) ;
183+ let num_add = terms. iter ( ) . len ( ) - 1 ;
184+
185+ tracing:: debug!(
186+ "layer name {} monomial num_add: {num_add} num_mul: {num_mul}" ,
187+ self . name,
188+ ) ;
189+ } ) ;
190+ }
191+
192+ self . main_sumcheck_expression_dag = {
193+ self . main_sumcheck_expression_monomial_terms
194+ . as_ref ( )
195+ . map ( |terms| {
196+ let ( dag, coeffs, Some ( final_out_index) , max_dag_depth) = build_factored_dag_commutative ( terms) else { panic ! ( ) } ;
197+ let stack_top = final_out_index + 1 ;
198+ let max_degree = zero_expr_degree;
199+
200+ let mut num_add = 0 ;
201+ let mut num_mul = 0 ;
202+
203+ for node in & dag {
204+ match node. op {
205+ 0 => ( ) , // skip wit index
206+ 1 => ( ) , // skip scalar index
207+ 2 => {
208+ num_add += 1 ;
209+ }
210+ 3 => {
211+ num_mul += 1 ;
212+ }
213+ op => panic ! ( "unknown op {op}" ) ,
214+ }
215+ }
216+
217+ tracing:: debug!(
218+ "layer name {} dag got num_add {num_add} num_mul {num_mul} max_degree {max_degree} \
219+ max_dag_depth {max_dag_depth} num_scalar {} final_out_index {final_out_index}",
220+ self . name,
221+ coeffs. len( ) ,
222+ ) ;
223+ ( dag, coeffs, stack_top, max_dag_depth as usize , zero_expr_degree)
224+ } )
225+ } ;
226+
225227 exit_span ! ( span) ;
226228 }
227229
0 commit comments