Skip to content

Commit 475b2ed

Browse files
WIP
1 parent a31776d commit 475b2ed

File tree

58 files changed

+4463
-1942
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

58 files changed

+4463
-1942
lines changed

Cargo.lock

Lines changed: 83 additions & 79 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/circuits/primitives/derive/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ license.workspace = true
1212
proc-macro = true
1313

1414
[dependencies]
15-
syn = { version = "2.0", features = ["parsing", "extra-traits"] }
15+
syn = { version = "2.0", features = ["full", "parsing", "extra-traits"] }
1616
quote = "1.0"
1717
itertools = { workspace = true, default-features = true }
1818
proc-macro2 = "1.0"

crates/circuits/primitives/derive/src/cols_ref/mod.rs

Lines changed: 64 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,20 @@
1+
/*
2+
* The `ColsRef` procedural macro is used in constraint generation to create column structs that
3+
* have dynamic sizes.
4+
*
5+
* Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the
6+
* same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384).
7+
* See the [SHA-2 VM extension](openvm/extensions/sha2/circuit/src/sha2_chip/air.rs) for an
8+
* example of how to use the `ColsRef` macro to reuse constraint generation code over multiple
9+
* circuits.
10+
*
11+
* This macro can also be used in other situations where we want to derive Borrow<T> for &[u8],
12+
* for some complicated struct T.
13+
*/
14+
mod utils;
15+
16+
use utils::*;
17+
118
extern crate proc_macro;
219

320
use itertools::Itertools;
@@ -169,7 +186,8 @@ fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_mac
169186
}
170187
}
171188

172-
// returns number of cells in the struct (where each cell has type T)
189+
// Returns number of cells in the struct (where each cell has type T).
190+
// This method should only be called if the struct has no primitive types (i.e. for columns structs).
173191
pub const fn width<C: #config>() -> usize {
174192
0 #( + #length_exprs )*
175193
}
@@ -227,7 +245,7 @@ fn make_from_mut(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_m
227245
&other.#ident
228246
}
229247
} else {
230-
panic!("Unsupported field type: {:?}", f.ty);
248+
panic!("Unsupported field type (in make_from_mut): {:?}", f.ty);
231249
}
232250
})
233251
.collect_vec();
@@ -346,8 +364,28 @@ fn get_const_cols_ref_fields(
346364
#slice_var
347365
},
348366
}
367+
} else if is_primitive_type(&elem_type) {
368+
FieldInfo {
369+
ty: parse_quote! {
370+
&'a #elem_type
371+
},
372+
// Columns structs won't ever have primitive types, but this macro can be used on
373+
// other structs as well, to make it easy to borrow a struct from &[u8].
374+
// We just set length = 0 knowing that calling the width() method is undefined if
375+
// the struct has a primitive type.
376+
length_expr: quote! {
377+
0
378+
},
379+
prepare_subslice: quote! {
380+
let (#slice_var, slice) = slice.split_at(std::mem::size_of::<#elem_type>() #(* #dim_exprs)*);
381+
let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap();
382+
},
383+
initializer: quote! {
384+
#slice_var
385+
},
386+
}
349387
} else {
350-
panic!("Unsupported field type: {:?}", f.ty);
388+
panic!("Unsupported field type (in get_const_cols_ref_fields): {:?}", f.ty);
351389
}
352390
} else if derives_aligned_borrow {
353391
// treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config)
@@ -405,7 +443,7 @@ fn get_const_cols_ref_fields(
405443
},
406444
}
407445
} else {
408-
panic!("Unsupported field type: {:?}", f.ty);
446+
panic!("Unsupported field type (in get_mut_cols_ref_fields): {:?}", f.ty);
409447
}
410448
}
411449

@@ -485,8 +523,28 @@ fn get_mut_cols_ref_fields(
485523
#slice_var
486524
},
487525
}
526+
} else if is_primitive_type(&elem_type) {
527+
FieldInfo {
528+
ty: parse_quote! {
529+
&'a mut #elem_type
530+
},
531+
// Columns structs won't ever have primitive types, but this macro can be used on
532+
// other structs as well, to make it easy to borrow a struct from &[u8].
533+
// We just set length = 0 knowing that calling the width() method is undefined if
534+
// the struct has a primitive type.
535+
length_expr: quote! {
536+
0
537+
},
538+
prepare_subslice: quote! {
539+
let (#slice_var, slice) = slice.split_at_mut(std::mem::size_of::<#elem_type>() #(* #dim_exprs)*);
540+
let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap();
541+
},
542+
initializer: quote! {
543+
#slice_var
544+
},
545+
}
488546
} else {
489-
panic!("Unsupported field type: {:?}", f.ty);
547+
panic!("Unsupported field type (in get_mut_cols_ref_fields): {:?}", f.ty);
490548
}
491549
} else if derives_aligned_borrow {
492550
// treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config)
@@ -544,7 +602,7 @@ fn get_mut_cols_ref_fields(
544602
},
545603
}
546604
} else {
547-
panic!("Unsupported field type: {:?}", f.ty);
605+
panic!("Unsupported field type (in get_mut_cols_ref_fields): {:?}", f.ty);
548606
}
549607
}
550608

@@ -637,61 +695,3 @@ fn is_generic_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> bool {
637695
false
638696
}
639697
}
640-
641-
// Type of array dimension
642-
enum Dimension {
643-
ConstGeneric(syn::Expr),
644-
Other(syn::Expr),
645-
}
646-
647-
// Describes a nested array
648-
struct ArrayInfo {
649-
dims: Vec<Dimension>,
650-
elem_type: syn::Type,
651-
}
652-
653-
fn get_array_info(ty: &syn::Type, const_generics: &[&syn::Ident]) -> ArrayInfo {
654-
let dims = get_dims(ty, const_generics);
655-
let elem_type = get_elem_type(ty);
656-
ArrayInfo { dims, elem_type }
657-
}
658-
659-
fn get_elem_type(ty: &syn::Type) -> syn::Type {
660-
match ty {
661-
syn::Type::Array(array) => get_elem_type(array.elem.as_ref()),
662-
syn::Type::Path(_) => ty.clone(),
663-
_ => panic!("Unsupported type: {:?}", ty),
664-
}
665-
}
666-
667-
// Get a vector of the dimensions of the array
668-
// Each dimension is either a constant generic or a literal integer value
669-
fn get_dims(ty: &syn::Type, const_generics: &[&syn::Ident]) -> Vec<Dimension> {
670-
get_dims_impl(ty, const_generics)
671-
.into_iter()
672-
.rev()
673-
.collect()
674-
}
675-
676-
fn get_dims_impl(ty: &syn::Type, const_generics: &[&syn::Ident]) -> Vec<Dimension> {
677-
match ty {
678-
syn::Type::Array(array) => {
679-
let mut dims = get_dims_impl(array.elem.as_ref(), const_generics);
680-
match &array.len {
681-
syn::Expr::Path(syn::ExprPath { path, .. }) => {
682-
let len_ident = path.get_ident();
683-
if len_ident.is_some() && const_generics.contains(&len_ident.unwrap()) {
684-
dims.push(Dimension::ConstGeneric(array.len.clone()));
685-
} else {
686-
dims.push(Dimension::Other(array.len.clone()));
687-
}
688-
}
689-
syn::Expr::Lit(expr_lit) => dims.push(Dimension::Other(expr_lit.clone().into())),
690-
_ => panic!("Unsupported array length type"),
691-
}
692-
dims
693-
}
694-
syn::Type::Path(_) => Vec::new(),
695-
_ => panic!("Unsupported field type"),
696-
}
697-
}
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
use syn::{Expr, ExprBlock, ExprPath, Ident, Stmt, Type, TypePath};
2+
3+
pub fn is_primitive_type(ty: &Type) -> bool {
4+
match ty {
5+
Type::Path(TypePath { path, .. }) if path.segments.len() == 1 => {
6+
matches!(
7+
path.segments[0].ident.to_string().as_str(),
8+
"u8" | "u16"
9+
| "u32"
10+
| "u64"
11+
| "u128"
12+
| "usize"
13+
| "i8"
14+
| "i16"
15+
| "i32"
16+
| "i64"
17+
| "i128"
18+
| "isize"
19+
| "f32"
20+
| "f64"
21+
| "bool"
22+
| "char"
23+
)
24+
}
25+
_ => false,
26+
}
27+
}
28+
29+
// Type of array dimension
30+
pub enum Dimension {
31+
ConstGeneric(Expr),
32+
Other(Expr),
33+
}
34+
35+
// Describes a nested array
36+
pub struct ArrayInfo {
37+
pub dims: Vec<Dimension>,
38+
pub elem_type: Type,
39+
}
40+
41+
pub fn get_array_info(ty: &Type, const_generics: &[&Ident]) -> ArrayInfo {
42+
let dims = get_dims(ty, const_generics);
43+
let elem_type = get_elem_type(ty);
44+
ArrayInfo { dims, elem_type }
45+
}
46+
47+
fn get_elem_type(ty: &Type) -> Type {
48+
match ty {
49+
Type::Array(array) => get_elem_type(array.elem.as_ref()),
50+
Type::Path(_) => ty.clone(),
51+
_ => panic!("Unsupported type: {:?}", ty),
52+
}
53+
}
54+
55+
// Get a vector of the dimensions of the array
56+
// Each dimension is either a constant generic or a literal integer value
57+
fn get_dims(ty: &Type, const_generics: &[&Ident]) -> Vec<Dimension> {
58+
get_dims_impl(ty, const_generics)
59+
.into_iter()
60+
.rev()
61+
.collect()
62+
}
63+
64+
fn get_dims_impl(ty: &Type, const_generics: &[&Ident]) -> Vec<Dimension> {
65+
match ty {
66+
Type::Array(array) => {
67+
let mut dims = get_dims_impl(array.elem.as_ref(), const_generics);
68+
match &array.len {
69+
Expr::Block(syn::ExprBlock { block, .. }) => {
70+
if block.stmts.len() != 1 {
71+
panic!(
72+
"Expected exactly one statement in block, got: {:?}",
73+
block.stmts.len()
74+
);
75+
}
76+
if let Stmt::Expr(Expr::Path(expr_path), ..) = &block.stmts[0] {
77+
if let Some(len_ident) = expr_path.path.get_ident() {
78+
if const_generics.contains(&len_ident) {
79+
println!("Const generic new: {:?}", expr_path);
80+
dims.push(Dimension::ConstGeneric(expr_path.clone().into()));
81+
} else {
82+
dims.push(Dimension::Other(expr_path.clone().into()));
83+
}
84+
}
85+
}
86+
}
87+
Expr::Path(ExprPath { path, .. }) => {
88+
let len_ident = path.get_ident();
89+
if len_ident.is_some() && const_generics.contains(&len_ident.unwrap()) {
90+
println!("Const generic old: {:?}", array.len);
91+
dims.push(Dimension::ConstGeneric(array.len.clone()));
92+
} else {
93+
dims.push(Dimension::Other(array.len.clone()));
94+
}
95+
}
96+
Expr::Lit(expr_lit) => dims.push(Dimension::Other(expr_lit.clone().into())),
97+
_ => panic!("Unsupported array length type: {:?}", array.len),
98+
}
99+
dims
100+
}
101+
Type::Path(_) => Vec::new(),
102+
_ => panic!("Unsupported field type (in get_dims_impl)"),
103+
}
104+
}
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use openvm_circuit_primitives_derive::ColsRef;
2+
3+
trait ExampleConfig {
4+
const N: usize;
5+
}
6+
struct ExampleConfigImplA;
7+
impl ExampleConfig for ExampleConfigImplA {
8+
const N: usize = 5;
9+
}
10+
11+
#[allow(dead_code)]
12+
#[derive(ColsRef)]
13+
#[config(ExampleConfig)]
14+
struct ExampleCols<T, const N: usize> {
15+
arr: [T; N],
16+
// arr: [T; { N }],
17+
sum: T,
18+
// primitive: u32,
19+
// array_of_primitive: [u32; { N }],
20+
}
21+
22+
#[test]
23+
fn debug() {
24+
let input = [1, 2, 3, 4, 5, 15];
25+
let test: ExampleColsRef<u32> = ExampleColsRef::from::<ExampleConfigImplA>(&input);
26+
println!("{}, {}", test.arr, test.sum);
27+
}

crates/circuits/sha2-air/Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,17 @@ edition.workspace = true
88
openvm-circuit-primitives = { workspace = true }
99
openvm-stark-backend = { workspace = true }
1010
openvm-circuit-primitives-derive = { workspace = true }
11+
1112
sha2 = { version = "0.10", features = ["compress"] }
1213
rand.workspace = true
1314
ndarray.workspace = true
1415
num_enum = { workspace = true }
16+
itertools = { workspace = true }
1517

1618
[dev-dependencies]
1719
openvm-stark-sdk = { workspace = true }
1820
openvm-circuit = { workspace = true, features = ["test-utils"] }
1921

2022
[features]
2123
default = ["parallel"]
22-
parallel = ["openvm-stark-backend/parallel"]
24+
parallel = ["openvm-stark-backend/parallel"]

0 commit comments

Comments
 (0)