@@ -234,61 +234,106 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
234234 layer. n_instance,
235235 ) ;
236236
237- let (
238- dag,
239- instance_scalar_expr,
240- challenges_expr,
241- constant_expr,
242- stack_top,
243- ( max_degree, max_dag_depth) ,
244- ) = layer. main_sumcheck_expression_dag . as_ref ( ) . unwrap ( ) ;
237+ let ( proof_gpu, evals_gpu, challenges_gpu) = if layer. exprs . len ( ) > 200 {
238+ let (
239+ dag,
240+ instance_scalar_expr,
241+ challenges_expr,
242+ constant_expr,
243+ stack_top,
244+ ( max_degree, max_dag_depth) ,
245+ ) = layer. main_sumcheck_expression_dag . as_ref ( ) . unwrap ( ) ;
245246
246- // format: pub_io ++ challenge ++ constant
247- let term_coefficients = instance_scalar_expr
248- . iter ( )
249- . map ( |Instance ( id) | pub_io_evals[ * id] )
250- . chain (
251- challenges_expr
252- . iter ( )
253- . map ( |c| {
254- eval_by_expr_constant (
255- & pub_io_evals. iter ( ) . map ( |v| Either :: Right ( * v) ) . collect_vec ( ) ,
256- & main_sumcheck_challenges,
257- c,
258- )
259- } )
260- . chain ( constant_expr. iter ( ) . copied ( ) )
261- . map ( |either_v| match either_v {
262- Either :: Left ( base_field_val) => E :: from ( base_field_val) ,
263- Either :: Right ( ext_field_val) => ext_field_val,
264- } ) ,
265- )
266- . collect_vec ( ) ;
247+ // format: pub_io ++ challenge ++ constant
248+ let term_coefficients = instance_scalar_expr
249+ . iter ( )
250+ . map ( |Instance ( id) | pub_io_evals[ * id] )
251+ . chain (
252+ challenges_expr
253+ . iter ( )
254+ . map ( |c| {
255+ eval_by_expr_constant (
256+ & pub_io_evals. iter ( ) . map ( |v| Either :: Right ( * v) ) . collect_vec ( ) ,
257+ & main_sumcheck_challenges,
258+ c,
259+ )
260+ } )
261+ . chain ( constant_expr. iter ( ) . copied ( ) )
262+ . map ( |either_v| match either_v {
263+ Either :: Left ( base_field_val) => E :: from ( base_field_val) ,
264+ Either :: Right ( ext_field_val) => ext_field_val,
265+ } ) ,
266+ )
267+ . collect_vec ( ) ;
267268
268- let max_num_var = max_num_variables;
269+ let max_num_var = max_num_variables;
270+
271+ // Convert types for GPU function Call
272+ let basic_tr: & mut BasicTranscript < BB31Ext > =
273+ unsafe { & mut * ( transcript as * mut _ as * mut BasicTranscript < BB31Ext > ) } ;
274+ let term_coefficients_gl64: Vec < BB31Ext > =
275+ unsafe { std:: mem:: transmute ( term_coefficients) } ;
276+ let all_witins_gpu_gl64: Vec < & MultilinearExtensionGpu < BB31Ext > > =
277+ unsafe { std:: mem:: transmute ( all_witins_gpu) } ;
278+ let all_witins_gpu_type_gl64 =
279+ all_witins_gpu_gl64. iter ( ) . map ( |mle| & mle. mle ) . collect_vec ( ) ;
280+ cuda_hal
281+ . sumcheck
282+ . prove_generic_sumcheck_gpu_v2 (
283+ & cuda_hal,
284+ dag,
285+ * max_dag_depth,
286+ all_witins_gpu_type_gl64,
287+ & term_coefficients_gl64,
288+ max_num_var,
289+ * max_degree,
290+ * stack_top,
291+ basic_tr,
292+ )
293+ . unwrap ( )
294+ } else {
295+ // Calculate max_num_var and max_degree from the extracted relationships
296+ let ( term_coefficients, mle_indices_per_term, mle_size_info) =
297+ extract_mle_relationships_from_monomial_terms (
298+ & layer
299+ . main_sumcheck_expression_monomial_terms
300+ . clone ( )
301+ . unwrap ( ) ,
302+ & all_witins_gpu,
303+ & pub_io_evals. iter ( ) . map ( |v| Either :: Right ( * v) ) . collect_vec ( ) ,
304+ & main_sumcheck_challenges,
305+ ) ;
306+ let max_num_var = max_num_variables;
307+ let max_degree = mle_indices_per_term
308+ . iter ( )
309+ . map ( |indices| indices. len ( ) )
310+ . max ( )
311+ . unwrap_or ( 0 ) ;
312+
313+ // Convert types for GPU function Call
314+ let basic_tr: & mut BasicTranscript < BB31Ext > =
315+ unsafe { & mut * ( transcript as * mut _ as * mut BasicTranscript < BB31Ext > ) } ;
316+ let term_coefficients_gl64: Vec < BB31Ext > =
317+ unsafe { std:: mem:: transmute ( term_coefficients) } ;
318+ let all_witins_gpu_gl64: Vec < & MultilinearExtensionGpu < BB31Ext > > =
319+ unsafe { std:: mem:: transmute ( all_witins_gpu) } ;
320+ let all_witins_gpu_type_gl64 =
321+ all_witins_gpu_gl64. iter ( ) . map ( |mle| & mle. mle ) . collect_vec ( ) ;
322+ cuda_hal
323+ . sumcheck
324+ . prove_generic_sumcheck_gpu (
325+ & cuda_hal,
326+ all_witins_gpu_type_gl64,
327+ & mle_size_info,
328+ & term_coefficients_gl64,
329+ & mle_indices_per_term,
330+ max_num_var,
331+ max_degree,
332+ basic_tr,
333+ )
334+ . unwrap ( )
335+ } ;
269336
270- // Convert types for GPU function Call
271- let basic_tr: & mut BasicTranscript < BB31Ext > =
272- unsafe { & mut * ( transcript as * mut _ as * mut BasicTranscript < BB31Ext > ) } ;
273- let term_coefficients_gl64: Vec < BB31Ext > =
274- unsafe { std:: mem:: transmute ( term_coefficients) } ;
275- let all_witins_gpu_gl64: Vec < & MultilinearExtensionGpu < BB31Ext > > =
276- unsafe { std:: mem:: transmute ( all_witins_gpu) } ;
277- let all_witins_gpu_type_gl64 = all_witins_gpu_gl64. iter ( ) . map ( |mle| & mle. mle ) . collect_vec ( ) ;
278- let ( proof_gpu, evals_gpu, challenges_gpu) = cuda_hal
279- . sumcheck
280- . prove_generic_sumcheck_gpu_v2 (
281- & cuda_hal,
282- dag,
283- * max_dag_depth,
284- all_witins_gpu_type_gl64,
285- & term_coefficients_gl64,
286- max_num_var,
287- * max_degree,
288- * stack_top,
289- basic_tr,
290- )
291- . unwrap ( ) ;
292337 let evals_gpu = evals_gpu. into_iter ( ) . flatten ( ) . collect_vec ( ) ;
293338 let row_challenges = challenges_gpu. iter ( ) . map ( |c| c. elements ) . collect_vec ( ) ;
294339
0 commit comments