Skip to content

Commit 03f286c

Browse files
fix: fixed codegen
1 parent d04c6cf commit 03f286c

File tree

4 files changed

+308
-181
lines changed

4 files changed

+308
-181
lines changed

compiler/src/nfa/codegen/circom.rs

Lines changed: 92 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
//! - Start/accept state validation
1515
1616
use serde::Serialize;
17-
use std::collections::BTreeSet;
1817

1918
use crate::nfa::NFAGraph;
2019
use crate::nfa::error::{NFAError, NFAResult};
@@ -35,10 +34,10 @@ pub struct CircomInputs {
3534
pub next_states: Vec<usize>,
3635
#[serde(skip_serializing_if = "Option::is_none")]
3736
#[serde(rename = "captureGroupIds")]
38-
pub capture_group_ids: Option<Vec<usize>>,
37+
pub capture_group_ids: Option<Vec<Vec<usize>>>,
3938
#[serde(skip_serializing_if = "Option::is_none")]
4039
#[serde(rename = "captureGroupStarts")]
41-
pub capture_group_starts: Option<Vec<u8>>,
40+
pub capture_group_starts: Option<Vec<Vec<u8>>>,
4241
#[serde(skip_serializing_if = "Option::is_none")]
4342
#[serde(rename = "captureGroupStartIndices")]
4443
pub capture_group_start_indices: Option<Vec<usize>>,
@@ -85,18 +84,12 @@ pub fn generate_circom_code(
8584

8685
let (start_states, accept_states, transitions) = generate_circuit_data(nfa)?;
8786

88-
// Validate capture groups
89-
let capture_group_set: BTreeSet<_> = transitions
90-
.iter()
91-
.filter_map(|(_, _, _, _, cap)| cap.map(|(id, _)| id))
92-
.collect();
93-
94-
if !capture_group_set.is_empty() {
87+
if nfa.num_capture_groups > 0 {
9588
if let Some(max_bytes) = max_substring_bytes.as_ref() {
96-
if max_bytes.len() < capture_group_set.len() {
89+
if max_bytes.len() != nfa.num_capture_groups {
9790
return Err(NFAError::InvalidCapture(format!(
9891
"Insufficient max_substring_bytes: need {} but got {}",
99-
capture_group_set.len(),
92+
nfa.num_capture_groups,
10093
max_bytes.len()
10194
)));
10295
}
@@ -114,8 +107,6 @@ pub fn generate_circom_code(
114107
}
115108
}
116109

117-
let has_capture_groups = !capture_group_set.is_empty();
118-
119110
let mut code = String::new();
120111

121112
code.push_str("pragma circom 2.1.5;\n\n");
@@ -143,9 +134,21 @@ pub fn generate_circom_code(
143134
code.push_str(" signal input nextStates[maxMatchBytes];\n");
144135

145136
// Only add capture group signals if needed
146-
if has_capture_groups {
147-
code.push_str(" signal input captureGroupIds[maxMatchBytes];\n");
148-
code.push_str(" signal input captureGroupStarts[maxMatchBytes];\n\n");
137+
if nfa.num_capture_groups > 0 {
138+
for i in 0..nfa.num_capture_groups {
139+
code.push_str(
140+
format!(" signal input captureGroup{}Id[maxMatchBytes];\n", i + 1).as_str(),
141+
);
142+
}
143+
for i in 0..nfa.num_capture_groups {
144+
code.push_str(
145+
format!(
146+
" signal input captureGroup{}Start[maxMatchBytes];\n",
147+
i + 1
148+
)
149+
.as_str(),
150+
);
151+
}
149152
}
150153

151154
code.push_str(" signal output isValid;\n\n");
@@ -234,63 +237,103 @@ pub fn generate_circom_code(
234237
code.push_str(" isTransitionLinked[i] === isWithinPathLengthMinusOne[i];\n");
235238
code.push_str(" }\n\n");
236239

237-
if has_capture_groups {
240+
if nfa.num_capture_groups > 0 {
241+
// Prepare strings for input signal arrays, used in each transition's Circom call
242+
let input_signal_cg_ids_list_str = (1..=nfa.num_capture_groups)
243+
.map(|k| format!("captureGroup{}Ids[i]", k))
244+
.collect::<Vec<_>>()
245+
.join(", ");
246+
let input_signal_cg_starts_list_str = (1..=nfa.num_capture_groups)
247+
.map(|k| format!("captureGroup{}Starts[i]", k))
248+
.collect::<Vec<_>>()
249+
.join(", ");
250+
238251
for (transition_idx, (curr_state, start, end, next_state, capture_info)) in
239252
transitions.iter().enumerate()
240253
{
241-
let (capture_group_id, capture_group_start) = match capture_info {
242-
Some(capture_info) => (capture_info.0, capture_info.1 as u8),
243-
None => (0, 0),
254+
// These vectors store the properties of *this specific transition*
255+
// regarding which capture groups it affects and how.
256+
let mut transition_prop_cg_ids = vec![0; nfa.num_capture_groups];
257+
let mut transition_prop_cg_starts = vec![0; nfa.num_capture_groups];
258+
259+
let capture_details_for_comment = capture_info
260+
.as_ref()
261+
.map(|infos| {
262+
infos
263+
.iter()
264+
.map(|(id, is_start_bool)| {
265+
if *id > 0 && *id <= nfa.num_capture_groups {
266+
transition_prop_cg_ids[*id - 1] = *id;
267+
transition_prop_cg_starts[*id - 1] = *is_start_bool as u8;
268+
}
269+
format!("({}, {})", id, *is_start_bool as u8)
270+
})
271+
.collect::<Vec<String>>()
272+
.join(", ")
273+
})
274+
.unwrap_or_else(|| "".to_string());
275+
276+
let capture_comment_segment = if capture_details_for_comment.is_empty() {
277+
"Capture Group: []".to_string()
278+
} else {
279+
format!("Capture Group:[ {}]", capture_details_for_comment)
244280
};
245281

282+
// String representation of this transition's capture group properties
283+
let transition_prop_cg_ids_str = transition_prop_cg_ids
284+
.iter()
285+
.map(ToString::to_string)
286+
.collect::<Vec<_>>()
287+
.join(", ");
288+
let transition_prop_cg_starts_str = transition_prop_cg_starts
289+
.iter()
290+
.map(ToString::to_string)
291+
.collect::<Vec<_>>()
292+
.join(", ");
293+
246294
if start == end {
247295
code.push_str(
248296
format!(
249-
" // Transition {}: {} -[{}]-> {} | Capture Group: ({}, {})\n",
250-
transition_idx,
251-
curr_state,
252-
start,
253-
next_state,
254-
capture_group_id,
255-
capture_group_start
297+
" // Transition {}: {} -[{}]-> {} | {}\n",
298+
transition_idx, curr_state, start, next_state, capture_comment_segment
256299
)
257300
.as_str(),
258301
);
259302
code.push_str(
260303
format!(
261-
" isValidTransition[{}][i] <== CheckByteTransitionWithCapture()({}, {}, {}, {}, {}, currStates[i], nextStates[i], haystack[i], captureGroupIds[i], captureGroupStarts[i]);\n",
304+
" isValidTransition[{}][i] <== CheckByteTransitionWithCapture({})({}, {}, {}, [{}], [{}], currStates[i], nextStates[i], haystack[i], [{}], [{}]);\n",
262305
transition_idx,
306+
nfa.num_capture_groups,
263307
curr_state,
264308
next_state,
265309
start,
266-
capture_group_id,
267-
capture_group_start
310+
transition_prop_cg_ids_str,
311+
transition_prop_cg_starts_str,
312+
input_signal_cg_ids_list_str, // Array of input signals
313+
input_signal_cg_starts_list_str // Array of input signals
268314
).as_str()
269315
);
270316
} else {
271317
code.push_str(
272318
format!(
273-
" // Transition {}: {} -[{}-{}]-> {} | Capture Group: ({}, {})\n",
274-
transition_idx,
275-
curr_state,
276-
start,
277-
end,
278-
next_state,
279-
capture_group_id,
280-
capture_group_start
319+
" // Transition {}: {} -[{}-{}]-> {} | {}\n",
320+
transition_idx, curr_state, start, end, next_state, capture_comment_segment
281321
)
282322
.as_str(),
283323
);
284324
code.push_str(
285325
format!(
286-
" isValidTransition[{}][i] <== CheckByteRangeTransitionWithCapture()({}, {}, {}, {}, {}, {}, currStates[i], nextStates[i], haystack[i], captureGroupIds[i], captureGroupStarts[i]);\n",
326+
" isValidTransition[{}][i] <== CheckByteRangeTransitionWithCapture({})({}, {}, {}, {}, [{}], [{}], currStates[i], nextStates[i], haystack[i], [{}], [{}]);\n",
287327
transition_idx,
328+
nfa.num_capture_groups,
288329
curr_state,
289330
next_state,
290331
start,
291332
end,
292-
capture_group_id,
293-
capture_group_start
333+
transition_prop_cg_ids_str,
334+
transition_prop_cg_starts_str,
335+
input_signal_cg_ids_list_str, // Array of input signals
336+
input_signal_cg_starts_list_str // Array of input signals
294337
).as_str()
295338
);
296339
}
@@ -375,15 +418,15 @@ pub fn generate_circom_code(
375418
code.push_str(" }\n\n");
376419
code.push_str(" isValid <== isValidRegex[maxMatchBytes-1];\n\n");
377420

378-
if has_capture_groups {
421+
if nfa.num_capture_groups > 0 {
379422
code.push_str(
380423
format!(
381424
" signal input captureGroupStartIndices[{}];\n\n",
382-
capture_group_set.len()
425+
nfa.num_capture_groups
383426
)
384427
.as_str(),
385428
);
386-
for capture_group_id in capture_group_set {
429+
for capture_group_id in 1..=nfa.num_capture_groups {
387430
let max_substring_bytes =
388431
if let Some(max_substring_bytes) = max_substring_bytes.as_ref() {
389432
max_substring_bytes[capture_group_id - 1]
@@ -397,12 +440,14 @@ pub fn generate_circom_code(
397440
code.push_str(format!(" // Capture Group {}\n", capture_group_id).as_str());
398441
code.push_str(
399442
format!(
400-
" signal output capture{}[{}] <== CaptureSubstring(maxMatchBytes, {}, {})(captureGroupStartIndices[{}], haystack, captureGroupIds, captureGroupStarts);\n",
443+
" signal output capture{}[{}] <== CaptureSubstring(maxMatchBytes, {}, {})(captureGroupStartIndices[{}], haystack, captureGroup{}Id, captureGroup{}Start);\n",
401444
capture_group_id,
402445
max_substring_bytes,
403446
max_substring_bytes,
404447
capture_group_id,
405-
capture_group_id - 1
448+
capture_group_id,
449+
capture_group_id,
450+
capture_group_id
406451
).as_str()
407452
);
408453
}

compiler/src/nfa/codegen/mod.rs

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use circom::CircomInputs;
77
use noir::NoirInputs;
88
use regex_automata::meta::Regex;
99
use serde::Serialize;
10-
use std::collections::BTreeMap;
10+
use std::collections::{BTreeMap, BTreeSet};
1111

1212
use crate::{
1313
ProverInputs, ProvingFramework,
@@ -25,9 +25,9 @@ pub struct CircuitInputs {
2525
curr_states: Vec<usize>,
2626
next_states: Vec<usize>,
2727
#[serde(skip_serializing_if = "Option::is_none")]
28-
capture_group_ids: Option<Vec<usize>>,
28+
capture_group_ids: Option<Vec<Vec<usize>>>,
2929
#[serde(skip_serializing_if = "Option::is_none")]
30-
capture_group_starts: Option<Vec<u8>>,
30+
capture_group_starts: Option<Vec<Vec<u8>>>,
3131
#[serde(skip_serializing_if = "Option::is_none")]
3232
capture_group_start_indices: Option<Vec<usize>>,
3333
}
@@ -37,7 +37,7 @@ pub fn generate_circuit_data(
3737
) -> NFAResult<(
3838
Vec<usize>,
3939
Vec<usize>,
40-
Vec<(usize, u8, u8, usize, Option<(usize, bool)>)>,
40+
Vec<(usize, u8, u8, usize, Option<BTreeSet<(usize, bool)>>)>,
4141
)> {
4242
if nfa.start_states.is_empty() {
4343
return Err(NFAError::Verification("NFA has no start states".into()));
@@ -60,7 +60,8 @@ pub fn generate_circuit_data(
6060

6161
// Group and convert to ranges - use BTreeMap for deterministic ordering
6262
let mut range_transitions = Vec::new();
63-
let mut grouped: BTreeMap<(usize, usize, Option<(usize, bool)>), Vec<u8>> = BTreeMap::new();
63+
let mut grouped: BTreeMap<(usize, usize, Option<BTreeSet<(usize, bool)>>), Vec<u8>> =
64+
BTreeMap::new();
6465

6566
for (src, byte, dst, capture) in transitions {
6667
if src >= nfa.nodes.len() || dst >= nfa.nodes.len() {
@@ -75,7 +76,12 @@ pub fn generate_circuit_data(
7576
// Convert to ranges
7677
for ((src, dst, capture), mut bytes) in grouped {
7778
if bytes.is_empty() {
78-
continue;
79+
// This case should ideally not be reached if the grouping logic is correct
80+
// and transitions always have associated bytes.
81+
return Err(NFAError::InvalidTransition(format!(
82+
"Found an empty byte list for transition group (src: {}, dst: {}, capture: {:?}). This indicates an issue with NFA processing.",
83+
src, dst, capture
84+
)));
7985
}
8086

8187
bytes.sort_unstable();
@@ -84,12 +90,12 @@ pub fn generate_circuit_data(
8490

8591
for &byte in &bytes[1..] {
8692
if byte != prev + 1 {
87-
range_transitions.push((src, start, prev, dst, capture));
93+
range_transitions.push((src, start, prev, dst, capture.clone()));
8894
start = byte;
8995
}
9096
prev = byte;
9197
}
92-
range_transitions.push((src, start, prev, dst, capture));
98+
range_transitions.push((src, start, prev, dst, capture.clone()));
9399
}
94100

95101
Ok((start_states, accept_states, range_transitions))
@@ -144,14 +150,33 @@ pub fn generate_circuit_inputs(
144150
// Handle capture groups if they exist
145151
let (capture_group_ids, capture_group_starts, capture_group_start_indices) =
146152
if path.iter().any(|(_, _, _, c)| c.is_some()) {
147-
let mut ids = path
148-
.iter()
149-
.map(|(_, _, _, c)| c.map(|(id, _)| id).unwrap_or(0))
150-
.collect::<Vec<_>>();
151-
let mut starts = path
152-
.iter()
153-
.map(|(_, _, _, c)| c.map(|(_, start)| start as u8).unwrap_or(0))
154-
.collect::<Vec<_>>();
153+
// Initialize structures:
154+
// capture_group_ids[group_idx_0_based][step_idx]
155+
let mut capture_group_ids: Vec<Vec<usize>> =
156+
vec![vec![0; max_match_len]; nfa.num_capture_groups];
157+
// capture_group_starts[group_idx_0_based][step_idx]
158+
let mut capture_group_starts: Vec<Vec<u8>> =
159+
vec![vec![0; max_match_len]; nfa.num_capture_groups];
160+
161+
// Populate these based on the actual path traversal
162+
// path_len is the actual number of steps taken for the match
163+
for step_idx in 0..path_len {
164+
// path[step_idx].3 is the Option<BTreeSet<(usize group_id, bool is_start)>>
165+
if let Some(capture_set) = &path[step_idx].3 {
166+
for (group_id, is_start) in capture_set.iter() {
167+
// group_id is 1-based from the regex engine
168+
if *group_id > 0 && *group_id <= nfa.num_capture_groups {
169+
let group_vector_idx = *group_id - 1; // Convert to 0-based for vector access
170+
171+
// Record the group ID if active
172+
capture_group_ids[group_vector_idx][step_idx] = *group_id;
173+
// Record if it's a start or end/continuation
174+
capture_group_starts[group_vector_idx][step_idx] =
175+
if *is_start { 1 } else { 0 };
176+
}
177+
}
178+
}
179+
}
155180

156181
// Use regex_automata to get capture start indices
157182
let re = Regex::new(&nfa.regex).map_err(|e| {
@@ -165,11 +190,11 @@ pub fn generate_circuit_inputs(
165190
.map(|m| m.start)
166191
.collect();
167192

168-
// Pad arrays
169-
ids.resize(max_match_len, 0);
170-
starts.resize(max_match_len, 0);
171-
172-
(Some(ids), Some(starts), Some(start_indices))
193+
(
194+
Some(capture_group_ids),
195+
Some(capture_group_starts),
196+
Some(start_indices),
197+
)
173198
} else {
174199
(None, None, None)
175200
};

0 commit comments

Comments
 (0)