1414//! - Start/accept state validation
1515
1616use serde:: Serialize ;
17- use std:: collections:: BTreeSet ;
1817
1918use crate :: nfa:: NFAGraph ;
2019use crate :: nfa:: error:: { NFAError , NFAResult } ;
@@ -35,10 +34,10 @@ pub struct CircomInputs {
3534 pub next_states : Vec < usize > ,
3635 #[ serde( skip_serializing_if = "Option::is_none" ) ]
3736 #[ serde( rename = "captureGroupIds" ) ]
38- pub capture_group_ids : Option < Vec < usize > > ,
37+ pub capture_group_ids : Option < Vec < Vec < usize > > > ,
3938 #[ serde( skip_serializing_if = "Option::is_none" ) ]
4039 #[ serde( rename = "captureGroupStarts" ) ]
41- pub capture_group_starts : Option < Vec < u8 > > ,
40+ pub capture_group_starts : Option < Vec < Vec < u8 > > > ,
4241 #[ serde( skip_serializing_if = "Option::is_none" ) ]
4342 #[ serde( rename = "captureGroupStartIndices" ) ]
4443 pub capture_group_start_indices : Option < Vec < usize > > ,
@@ -85,18 +84,12 @@ pub fn generate_circom_code(
8584
8685 let ( start_states, accept_states, transitions) = generate_circuit_data ( nfa) ?;
8786
88- // Validate capture groups
89- let capture_group_set: BTreeSet < _ > = transitions
90- . iter ( )
91- . filter_map ( |( _, _, _, _, cap) | cap. map ( |( id, _) | id) )
92- . collect ( ) ;
93-
94- if !capture_group_set. is_empty ( ) {
87+ if nfa. num_capture_groups > 0 {
9588 if let Some ( max_bytes) = max_substring_bytes. as_ref ( ) {
96- if max_bytes. len ( ) < capture_group_set . len ( ) {
89+ if max_bytes. len ( ) != nfa . num_capture_groups {
9790 return Err ( NFAError :: InvalidCapture ( format ! (
9891 "Insufficient max_substring_bytes: need {} but got {}" ,
99- capture_group_set . len ( ) ,
92+ nfa . num_capture_groups ,
10093 max_bytes. len( )
10194 ) ) ) ;
10295 }
@@ -114,8 +107,6 @@ pub fn generate_circom_code(
114107 }
115108 }
116109
117- let has_capture_groups = !capture_group_set. is_empty ( ) ;
118-
119110 let mut code = String :: new ( ) ;
120111
121112 code. push_str ( "pragma circom 2.1.5;\n \n " ) ;
@@ -143,9 +134,21 @@ pub fn generate_circom_code(
143134 code. push_str ( " signal input nextStates[maxMatchBytes];\n " ) ;
144135
145136 // Only add capture group signals if needed
146- if has_capture_groups {
147- code. push_str ( " signal input captureGroupIds[maxMatchBytes];\n " ) ;
148- code. push_str ( " signal input captureGroupStarts[maxMatchBytes];\n \n " ) ;
137+ if nfa. num_capture_groups > 0 {
138+ for i in 0 ..nfa. num_capture_groups {
139+ code. push_str (
140+ format ! ( " signal input captureGroup{}Id[maxMatchBytes];\n " , i + 1 ) . as_str ( ) ,
141+ ) ;
142+ }
143+ for i in 0 ..nfa. num_capture_groups {
144+ code. push_str (
145+ format ! (
146+ " signal input captureGroup{}Start[maxMatchBytes];\n " ,
147+ i + 1
148+ )
149+ . as_str ( ) ,
150+ ) ;
151+ }
149152 }
150153
151154 code. push_str ( " signal output isValid;\n \n " ) ;
@@ -234,63 +237,103 @@ pub fn generate_circom_code(
234237 code. push_str ( " isTransitionLinked[i] === isWithinPathLengthMinusOne[i];\n " ) ;
235238 code. push_str ( " }\n \n " ) ;
236239
237- if has_capture_groups {
240+ if nfa. num_capture_groups > 0 {
241+ // Prepare strings for input signal arrays, used in each transition's Circom call
242+ let input_signal_cg_ids_list_str = ( 1 ..=nfa. num_capture_groups )
243+ . map ( |k| format ! ( "captureGroup{}Ids[i]" , k) )
244+ . collect :: < Vec < _ > > ( )
245+ . join ( ", " ) ;
246+ let input_signal_cg_starts_list_str = ( 1 ..=nfa. num_capture_groups )
247+ . map ( |k| format ! ( "captureGroup{}Starts[i]" , k) )
248+ . collect :: < Vec < _ > > ( )
249+ . join ( ", " ) ;
250+
238251 for ( transition_idx, ( curr_state, start, end, next_state, capture_info) ) in
239252 transitions. iter ( ) . enumerate ( )
240253 {
241- let ( capture_group_id, capture_group_start) = match capture_info {
242- Some ( capture_info) => ( capture_info. 0 , capture_info. 1 as u8 ) ,
243- None => ( 0 , 0 ) ,
254+ // These vectors store the properties of *this specific transition*
255+ // regarding which capture groups it affects and how.
256+ let mut transition_prop_cg_ids = vec ! [ 0 ; nfa. num_capture_groups] ;
257+ let mut transition_prop_cg_starts = vec ! [ 0 ; nfa. num_capture_groups] ;
258+
259+ let capture_details_for_comment = capture_info
260+ . as_ref ( )
261+ . map ( |infos| {
262+ infos
263+ . iter ( )
264+ . map ( |( id, is_start_bool) | {
265+ if * id > 0 && * id <= nfa. num_capture_groups {
266+ transition_prop_cg_ids[ * id - 1 ] = * id;
267+ transition_prop_cg_starts[ * id - 1 ] = * is_start_bool as u8 ;
268+ }
269+ format ! ( "({}, {})" , id, * is_start_bool as u8 )
270+ } )
271+ . collect :: < Vec < String > > ( )
272+ . join ( ", " )
273+ } )
274+ . unwrap_or_else ( || "" . to_string ( ) ) ;
275+
276+ let capture_comment_segment = if capture_details_for_comment. is_empty ( ) {
277+ "Capture Group: []" . to_string ( )
278+ } else {
279+ format ! ( "Capture Group:[ {}]" , capture_details_for_comment)
244280 } ;
245281
282+ // String representation of this transition's capture group properties
283+ let transition_prop_cg_ids_str = transition_prop_cg_ids
284+ . iter ( )
285+ . map ( ToString :: to_string)
286+ . collect :: < Vec < _ > > ( )
287+ . join ( ", " ) ;
288+ let transition_prop_cg_starts_str = transition_prop_cg_starts
289+ . iter ( )
290+ . map ( ToString :: to_string)
291+ . collect :: < Vec < _ > > ( )
292+ . join ( ", " ) ;
293+
246294 if start == end {
247295 code. push_str (
248296 format ! (
249- " // Transition {}: {} -[{}]-> {} | Capture Group: ({}, {})\n " ,
250- transition_idx,
251- curr_state,
252- start,
253- next_state,
254- capture_group_id,
255- capture_group_start
297+ " // Transition {}: {} -[{}]-> {} | {}\n " ,
298+ transition_idx, curr_state, start, next_state, capture_comment_segment
256299 )
257300 . as_str ( ) ,
258301 ) ;
259302 code. push_str (
260303 format ! (
261- " isValidTransition[{}][i] <== CheckByteTransitionWithCapture()({}, {}, {}, {}, {} , currStates[i], nextStates[i], haystack[i], captureGroupIds[i ], captureGroupStarts[i ]);\n " ,
304+ " isValidTransition[{}][i] <== CheckByteTransitionWithCapture({} )({}, {}, {}, [{}], [{}] , currStates[i], nextStates[i], haystack[i], [{} ], [{} ]);\n " ,
262305 transition_idx,
306+ nfa. num_capture_groups,
263307 curr_state,
264308 next_state,
265309 start,
266- capture_group_id,
267- capture_group_start
310+ transition_prop_cg_ids_str,
311+ transition_prop_cg_starts_str,
312+ input_signal_cg_ids_list_str, // Array of input signals
313+ input_signal_cg_starts_list_str // Array of input signals
268314 ) . as_str ( )
269315 ) ;
270316 } else {
271317 code. push_str (
272318 format ! (
273- " // Transition {}: {} -[{}-{}]-> {} | Capture Group: ({}, {})\n " ,
274- transition_idx,
275- curr_state,
276- start,
277- end,
278- next_state,
279- capture_group_id,
280- capture_group_start
319+ " // Transition {}: {} -[{}-{}]-> {} | {}\n " ,
320+ transition_idx, curr_state, start, end, next_state, capture_comment_segment
281321 )
282322 . as_str ( ) ,
283323 ) ;
284324 code. push_str (
285325 format ! (
286- " isValidTransition[{}][i] <== CheckByteRangeTransitionWithCapture()({}, {}, {}, {}, {}, {} , currStates[i], nextStates[i], haystack[i], captureGroupIds[i ], captureGroupStarts[i ]);\n " ,
326+ " isValidTransition[{}][i] <== CheckByteRangeTransitionWithCapture({} )({}, {}, {}, {}, [{}], [{}] , currStates[i], nextStates[i], haystack[i], [{} ], [{} ]);\n " ,
287327 transition_idx,
328+ nfa. num_capture_groups,
288329 curr_state,
289330 next_state,
290331 start,
291332 end,
292- capture_group_id,
293- capture_group_start
333+ transition_prop_cg_ids_str,
334+ transition_prop_cg_starts_str,
335+ input_signal_cg_ids_list_str, // Array of input signals
336+ input_signal_cg_starts_list_str // Array of input signals
294337 ) . as_str ( )
295338 ) ;
296339 }
@@ -375,15 +418,15 @@ pub fn generate_circom_code(
375418 code. push_str ( " }\n \n " ) ;
376419 code. push_str ( " isValid <== isValidRegex[maxMatchBytes-1];\n \n " ) ;
377420
378- if has_capture_groups {
421+ if nfa . num_capture_groups > 0 {
379422 code. push_str (
380423 format ! (
381424 " signal input captureGroupStartIndices[{}];\n \n " ,
382- capture_group_set . len ( )
425+ nfa . num_capture_groups
383426 )
384427 . as_str ( ) ,
385428 ) ;
386- for capture_group_id in capture_group_set {
429+ for capture_group_id in 1 ..=nfa . num_capture_groups {
387430 let max_substring_bytes =
388431 if let Some ( max_substring_bytes) = max_substring_bytes. as_ref ( ) {
389432 max_substring_bytes[ capture_group_id - 1 ]
@@ -397,12 +440,14 @@ pub fn generate_circom_code(
397440 code. push_str ( format ! ( " // Capture Group {}\n " , capture_group_id) . as_str ( ) ) ;
398441 code. push_str (
399442 format ! (
400- " signal output capture{}[{}] <== CaptureSubstring(maxMatchBytes, {}, {})(captureGroupStartIndices[{}], haystack, captureGroupIds, captureGroupStarts );\n " ,
443+ " signal output capture{}[{}] <== CaptureSubstring(maxMatchBytes, {}, {})(captureGroupStartIndices[{}], haystack, captureGroup{}Id, captureGroup{}Start );\n " ,
401444 capture_group_id,
402445 max_substring_bytes,
403446 max_substring_bytes,
404447 capture_group_id,
405- capture_group_id - 1
448+ capture_group_id,
449+ capture_group_id,
450+ capture_group_id
406451 ) . as_str ( )
407452 ) ;
408453 }
0 commit comments