|
| 1 | +/* |
| 2 | + * The `ColsRef` procedural macro is used in constraint generation to create column structs that |
| 3 | + * have dynamic sizes. |
| 4 | + * |
| 5 | + * Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the |
| 6 | + * same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384). |
| 7 | + * See the [SHA-2 VM extension](openvm/extensions/sha2/circuit/src/sha2_chip/air.rs) for an |
| 8 | + * example of how to use the `ColsRef` macro to reuse constraint generation code over multiple |
| 9 | + * circuits. |
| 10 | + * |
| 11 | + * This macro can also be used in other situations where we want to derive Borrow<T> for &[u8], |
| 12 | + * for some complicated struct T. |
| 13 | + */ |
| 14 | +mod utils; |
| 15 | + |
| 16 | +use utils::*; |
| 17 | + |
1 | 18 | extern crate proc_macro; |
2 | 19 |
|
3 | 20 | use itertools::Itertools; |
@@ -169,7 +186,8 @@ fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_mac |
169 | 186 | } |
170 | 187 | } |
171 | 188 |
|
172 | | - // returns number of cells in the struct (where each cell has type T) |
| 189 | + // Returns number of cells in the struct (where each cell has type T). |
| 190 | + // This method should only be called if the struct has no primitive types (i.e. for columns structs). |
173 | 191 | pub const fn width<C: #config>() -> usize { |
174 | 192 | 0 #( + #length_exprs )* |
175 | 193 | } |
@@ -227,7 +245,7 @@ fn make_from_mut(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_m |
227 | 245 | &other.#ident |
228 | 246 | } |
229 | 247 | } else { |
230 | | - panic!("Unsupported field type: {:?}", f.ty); |
| 248 | + panic!("Unsupported field type (in make_from_mut): {:?}", f.ty); |
231 | 249 | } |
232 | 250 | }) |
233 | 251 | .collect_vec(); |
@@ -346,8 +364,28 @@ fn get_const_cols_ref_fields( |
346 | 364 | #slice_var |
347 | 365 | }, |
348 | 366 | } |
| 367 | + } else if is_primitive_type(&elem_type) { |
| 368 | + FieldInfo { |
| 369 | + ty: parse_quote! { |
| 370 | + &'a #elem_type |
| 371 | + }, |
| 372 | + // Columns structs won't ever have primitive types, but this macro can be used on |
| 373 | + // other structs as well, to make it easy to borrow a struct from &[u8]. |
| 374 | + // We just set length = 0 knowing that calling the width() method is undefined if |
| 375 | + // the struct has a primitive type. |
| 376 | + length_expr: quote! { |
| 377 | + 0 |
| 378 | + }, |
| 379 | + prepare_subslice: quote! { |
| 380 | + let (#slice_var, slice) = slice.split_at(std::mem::size_of::<#elem_type>() #(* #dim_exprs)*); |
| 381 | + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); |
| 382 | + }, |
| 383 | + initializer: quote! { |
| 384 | + #slice_var |
| 385 | + }, |
| 386 | + } |
349 | 387 | } else { |
350 | | - panic!("Unsupported field type: {:?}", f.ty); |
| 388 | + panic!("Unsupported field type (in get_const_cols_ref_fields): {:?}", f.ty); |
351 | 389 | } |
352 | 390 | } else if derives_aligned_borrow { |
353 | 391 | // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config) |
@@ -405,7 +443,7 @@ fn get_const_cols_ref_fields( |
405 | 443 | }, |
406 | 444 | } |
407 | 445 | } else { |
408 | | - panic!("Unsupported field type: {:?}", f.ty); |
| 446 | + panic!("Unsupported field type (in get_mut_cols_ref_fields): {:?}", f.ty); |
409 | 447 | } |
410 | 448 | } |
411 | 449 |
|
@@ -485,8 +523,28 @@ fn get_mut_cols_ref_fields( |
485 | 523 | #slice_var |
486 | 524 | }, |
487 | 525 | } |
| 526 | + } else if is_primitive_type(&elem_type) { |
| 527 | + FieldInfo { |
| 528 | + ty: parse_quote! { |
| 529 | + &'a mut #elem_type |
| 530 | + }, |
| 531 | + // Columns structs won't ever have primitive types, but this macro can be used on |
| 532 | + // other structs as well, to make it easy to borrow a struct from &[u8]. |
| 533 | + // We just set length = 0 knowing that calling the width() method is undefined if |
| 534 | + // the struct has a primitive type. |
| 535 | + length_expr: quote! { |
| 536 | + 0 |
| 537 | + }, |
| 538 | + prepare_subslice: quote! { |
| 539 | + let (#slice_var, slice) = slice.split_at_mut(std::mem::size_of::<#elem_type>() #(* #dim_exprs)*); |
| 540 | + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); |
| 541 | + }, |
| 542 | + initializer: quote! { |
| 543 | + #slice_var |
| 544 | + }, |
| 545 | + } |
488 | 546 | } else { |
489 | | - panic!("Unsupported field type: {:?}", f.ty); |
| 547 | + panic!("Unsupported field type (in get_mut_cols_ref_fields): {:?}", f.ty); |
490 | 548 | } |
491 | 549 | } else if derives_aligned_borrow { |
492 | 550 | // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config) |
@@ -544,7 +602,7 @@ fn get_mut_cols_ref_fields( |
544 | 602 | }, |
545 | 603 | } |
546 | 604 | } else { |
547 | | - panic!("Unsupported field type: {:?}", f.ty); |
| 605 | + panic!("Unsupported field type (in get_mut_cols_ref_fields): {:?}", f.ty); |
548 | 606 | } |
549 | 607 | } |
550 | 608 |
|
@@ -637,61 +695,3 @@ fn is_generic_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> bool { |
637 | 695 | false |
638 | 696 | } |
639 | 697 | } |
640 | | - |
641 | | -// Type of array dimension |
642 | | -enum Dimension { |
643 | | - ConstGeneric(syn::Expr), |
644 | | - Other(syn::Expr), |
645 | | -} |
646 | | - |
647 | | -// Describes a nested array |
648 | | -struct ArrayInfo { |
649 | | - dims: Vec<Dimension>, |
650 | | - elem_type: syn::Type, |
651 | | -} |
652 | | - |
653 | | -fn get_array_info(ty: &syn::Type, const_generics: &[&syn::Ident]) -> ArrayInfo { |
654 | | - let dims = get_dims(ty, const_generics); |
655 | | - let elem_type = get_elem_type(ty); |
656 | | - ArrayInfo { dims, elem_type } |
657 | | -} |
658 | | - |
659 | | -fn get_elem_type(ty: &syn::Type) -> syn::Type { |
660 | | - match ty { |
661 | | - syn::Type::Array(array) => get_elem_type(array.elem.as_ref()), |
662 | | - syn::Type::Path(_) => ty.clone(), |
663 | | - _ => panic!("Unsupported type: {:?}", ty), |
664 | | - } |
665 | | -} |
666 | | - |
667 | | -// Get a vector of the dimensions of the array |
668 | | -// Each dimension is either a constant generic or a literal integer value |
669 | | -fn get_dims(ty: &syn::Type, const_generics: &[&syn::Ident]) -> Vec<Dimension> { |
670 | | - get_dims_impl(ty, const_generics) |
671 | | - .into_iter() |
672 | | - .rev() |
673 | | - .collect() |
674 | | -} |
675 | | - |
676 | | -fn get_dims_impl(ty: &syn::Type, const_generics: &[&syn::Ident]) -> Vec<Dimension> { |
677 | | - match ty { |
678 | | - syn::Type::Array(array) => { |
679 | | - let mut dims = get_dims_impl(array.elem.as_ref(), const_generics); |
680 | | - match &array.len { |
681 | | - syn::Expr::Path(syn::ExprPath { path, .. }) => { |
682 | | - let len_ident = path.get_ident(); |
683 | | - if len_ident.is_some() && const_generics.contains(&len_ident.unwrap()) { |
684 | | - dims.push(Dimension::ConstGeneric(array.len.clone())); |
685 | | - } else { |
686 | | - dims.push(Dimension::Other(array.len.clone())); |
687 | | - } |
688 | | - } |
689 | | - syn::Expr::Lit(expr_lit) => dims.push(Dimension::Other(expr_lit.clone().into())), |
690 | | - _ => panic!("Unsupported array length type"), |
691 | | - } |
692 | | - dims |
693 | | - } |
694 | | - syn::Type::Path(_) => Vec::new(), |
695 | | - _ => panic!("Unsupported field type"), |
696 | | - } |
697 | | -} |
0 commit comments