Skip to content

Commit 98f5d9e

Browse files
committed
Use the DAG version only on larger circuits
1 parent dcec311 commit 98f5d9e

File tree

2 files changed

+103
-60
lines changed

2 files changed

+103
-60
lines changed

gkr_iop/src/gkr/layer/gpu/mod.rs

Lines changed: 97 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -234,61 +234,106 @@ impl<E: ExtensionField, PCS: PolynomialCommitmentScheme<E>> ZerocheckLayerProver
234234
layer.n_instance,
235235
);
236236

237-
let (
238-
dag,
239-
instance_scalar_expr,
240-
challenges_expr,
241-
constant_expr,
242-
stack_top,
243-
(max_degree, max_dag_depth),
244-
) = layer.main_sumcheck_expression_dag.as_ref().unwrap();
237+
let (proof_gpu, evals_gpu, challenges_gpu) = if layer.exprs.len() > 200 {
238+
let (
239+
dag,
240+
instance_scalar_expr,
241+
challenges_expr,
242+
constant_expr,
243+
stack_top,
244+
(max_degree, max_dag_depth),
245+
) = layer.main_sumcheck_expression_dag.as_ref().unwrap();
245246

246-
// format: pub_io ++ challenge ++ constant
247-
let term_coefficients = instance_scalar_expr
248-
.iter()
249-
.map(|Instance(id)| pub_io_evals[*id])
250-
.chain(
251-
challenges_expr
252-
.iter()
253-
.map(|c| {
254-
eval_by_expr_constant(
255-
&pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec(),
256-
&main_sumcheck_challenges,
257-
c,
258-
)
259-
})
260-
.chain(constant_expr.iter().copied())
261-
.map(|either_v| match either_v {
262-
Either::Left(base_field_val) => E::from(base_field_val),
263-
Either::Right(ext_field_val) => ext_field_val,
264-
}),
265-
)
266-
.collect_vec();
247+
// format: pub_io ++ challenge ++ constant
248+
let term_coefficients = instance_scalar_expr
249+
.iter()
250+
.map(|Instance(id)| pub_io_evals[*id])
251+
.chain(
252+
challenges_expr
253+
.iter()
254+
.map(|c| {
255+
eval_by_expr_constant(
256+
&pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec(),
257+
&main_sumcheck_challenges,
258+
c,
259+
)
260+
})
261+
.chain(constant_expr.iter().copied())
262+
.map(|either_v| match either_v {
263+
Either::Left(base_field_val) => E::from(base_field_val),
264+
Either::Right(ext_field_val) => ext_field_val,
265+
}),
266+
)
267+
.collect_vec();
267268

268-
let max_num_var = max_num_variables;
269+
let max_num_var = max_num_variables;
270+
271+
// Convert types for the GPU function call
272+
let basic_tr: &mut BasicTranscript<BB31Ext> =
273+
unsafe { &mut *(transcript as *mut _ as *mut BasicTranscript<BB31Ext>) };
274+
let term_coefficients_gl64: Vec<BB31Ext> =
275+
unsafe { std::mem::transmute(term_coefficients) };
276+
let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu<BB31Ext>> =
277+
unsafe { std::mem::transmute(all_witins_gpu) };
278+
let all_witins_gpu_type_gl64 =
279+
all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec();
280+
cuda_hal
281+
.sumcheck
282+
.prove_generic_sumcheck_gpu_v2(
283+
&cuda_hal,
284+
dag,
285+
*max_dag_depth,
286+
all_witins_gpu_type_gl64,
287+
&term_coefficients_gl64,
288+
max_num_var,
289+
*max_degree,
290+
*stack_top,
291+
basic_tr,
292+
)
293+
.unwrap()
294+
} else {
295+
// Calculate max_num_var and max_degree from the extracted relationships
296+
let (term_coefficients, mle_indices_per_term, mle_size_info) =
297+
extract_mle_relationships_from_monomial_terms(
298+
&layer
299+
.main_sumcheck_expression_monomial_terms
300+
.clone()
301+
.unwrap(),
302+
&all_witins_gpu,
303+
&pub_io_evals.iter().map(|v| Either::Right(*v)).collect_vec(),
304+
&main_sumcheck_challenges,
305+
);
306+
let max_num_var = max_num_variables;
307+
let max_degree = mle_indices_per_term
308+
.iter()
309+
.map(|indices| indices.len())
310+
.max()
311+
.unwrap_or(0);
312+
313+
// Convert types for the GPU function call
314+
let basic_tr: &mut BasicTranscript<BB31Ext> =
315+
unsafe { &mut *(transcript as *mut _ as *mut BasicTranscript<BB31Ext>) };
316+
let term_coefficients_gl64: Vec<BB31Ext> =
317+
unsafe { std::mem::transmute(term_coefficients) };
318+
let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu<BB31Ext>> =
319+
unsafe { std::mem::transmute(all_witins_gpu) };
320+
let all_witins_gpu_type_gl64 =
321+
all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec();
322+
cuda_hal
323+
.sumcheck
324+
.prove_generic_sumcheck_gpu(
325+
&cuda_hal,
326+
all_witins_gpu_type_gl64,
327+
&mle_size_info,
328+
&term_coefficients_gl64,
329+
&mle_indices_per_term,
330+
max_num_var,
331+
max_degree,
332+
basic_tr,
333+
)
334+
.unwrap()
335+
};
269336

270-
// Convert types for GPU function Call
271-
let basic_tr: &mut BasicTranscript<BB31Ext> =
272-
unsafe { &mut *(transcript as *mut _ as *mut BasicTranscript<BB31Ext>) };
273-
let term_coefficients_gl64: Vec<BB31Ext> =
274-
unsafe { std::mem::transmute(term_coefficients) };
275-
let all_witins_gpu_gl64: Vec<&MultilinearExtensionGpu<BB31Ext>> =
276-
unsafe { std::mem::transmute(all_witins_gpu) };
277-
let all_witins_gpu_type_gl64 = all_witins_gpu_gl64.iter().map(|mle| &mle.mle).collect_vec();
278-
let (proof_gpu, evals_gpu, challenges_gpu) = cuda_hal
279-
.sumcheck
280-
.prove_generic_sumcheck_gpu_v2(
281-
&cuda_hal,
282-
dag,
283-
*max_dag_depth,
284-
all_witins_gpu_type_gl64,
285-
&term_coefficients_gl64,
286-
max_num_var,
287-
*max_degree,
288-
*stack_top,
289-
basic_tr,
290-
)
291-
.unwrap();
292337
let evals_gpu = evals_gpu.into_iter().flatten().collect_vec();
293338
let row_challenges = challenges_gpu.iter().map(|c| c.elements).collect_vec();
294339

gkr_iop/src/gkr/layer/zerocheck_layer.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,22 +178,20 @@ impl<E: ExtensionField> ZerocheckLayer<E> for Layer<E> {
178178
(max_degree, max_dag_depth),
179179
) = expr_compression_to_dag(self.main_sumcheck_expression.as_ref().unwrap());
180180

181-
let mut traverse_dag_id = 0;
182181
let mut num_add = 0;
183182
let mut num_mul = 0;
184-
while traverse_dag_id < dag.len() {
185-
match dag[traverse_dag_id].op {
186-
0 => traverse_dag_id += 2, // skip wit index
187-
1 => traverse_dag_id += 2, // skip scalar index
183+
184+
for node in &dag {
185+
match node.op {
186+
0 => (), // skip wit index
187+
1 => (), // skip scalar index
188188
2 => {
189189
num_add += 1;
190-
traverse_dag_id += 1;
191190
}
192191
3 => {
193192
num_mul += 1;
194-
traverse_dag_id += 1;
195193
}
196-
_ => unreachable!(),
194+
op => panic!("unknown op {op}"),
197195
}
198196
}
199197

0 commit comments

Comments
 (0)