From 2d76977e6c26936f21f2d81f1b53e85fe1dd84d8 Mon Sep 17 00:00:00 2001 From: Robert Zakrzewski Date: Tue, 14 Jan 2025 13:43:37 +0100 Subject: [PATCH 1/4] Add Builtin trait and refactor all builtin functions to use it (#244) Refactor builtin function to Builtin trait --- src/stdlib/bits.rs | 151 ++++++++++++++------------ src/stdlib/builtins.rs | 234 +++++++++++++++++++++++------------------ src/stdlib/crypto.rs | 21 +++- src/stdlib/int.rs | 102 +++++++++--------- 4 files changed, 281 insertions(+), 227 deletions(-) diff --git a/src/stdlib/bits.rs b/src/stdlib/bits.rs index ac5ca3d01..8d87fe537 100644 --- a/src/stdlib/bits.rs +++ b/src/stdlib/bits.rs @@ -12,10 +12,7 @@ use crate::{ var::{ConstOrCell, Value, Var}, }; -use super::{FnInfoType, Module}; - -const NTH_BIT_FN: &str = "nth_bit(val: Field, const nth: Field) -> Field"; -const CHECK_FIELD_SIZE_FN: &str = "check_field_size(cmp: Field)"; +use super::{builtins::Builtin, FnInfoType, Module}; pub struct BitsLib {} @@ -24,81 +21,95 @@ impl Module for BitsLib { fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { vec![ - (NTH_BIT_FN, nth_bit, false), - (CHECK_FIELD_SIZE_FN, check_field_size, false), + (NthBitFn::SIGNATURE, NthBitFn::builtin, false), + ( + CheckFieldSizeFn::SIGNATURE, + CheckFieldSizeFn::builtin, + false, + ), ] } } -fn nth_bit( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // should be two input vars - assert_eq!(vars.len(), 2); - - // these should be type checked already, unless it is called by other low level functions - // eg. builtins - let var_info = &vars[0]; - let val = &var_info.var; - assert_eq!(val.len(), 1); - - let var_info = &vars[1]; - let nth = &var_info.var; - assert_eq!(nth.len(), 1); - - let nth: usize = match &nth[0] { - ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), - ConstOrCell::Const(cst) => cst.to_u64() as usize, - }; - - let val = match &val[0] { - ConstOrCell::Cell(cvar) => cvar.clone(), - ConstOrCell::Const(cst) => { - // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var - let bit = cst.to_bits(); - return Ok(Some(Var::new_cvar( - ConstOrCell::Const(B::Field::from(bit[nth])), - span, - ))); - } - }; +struct NthBitFn {} +struct CheckFieldSizeFn {} + +impl Builtin for NthBitFn { + const SIGNATURE: &'static str = "nth_bit(val: Field, const nth: Field) -> Field"; + + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // should be two input vars + assert_eq!(vars.len(), 2); + + // these should be type checked already, unless it is called by other low level functions + // eg. builtins + let var_info = &vars[0]; + let val = &var_info.var; + assert_eq!(val.len(), 1); + + let var_info = &vars[1]; + let nth = &var_info.var; + assert_eq!(nth.len(), 1); + + let nth: usize = match &nth[0] { + ConstOrCell::Cell(_) => unreachable!("nth should be a constant"), + ConstOrCell::Const(cst) => cst.to_u64() as usize, + }; + + let val = match &val[0] { + ConstOrCell::Cell(cvar) => cvar.clone(), + ConstOrCell::Const(cst) => { + // directly return the nth bit without adding symbolic value as it doesn't depend on a cell var + let bit = cst.to_bits(); + return Ok(Some(Var::new_cvar( + ConstOrCell::Const(B::Field::from(bit[nth])), + span, + ))); + } + }; - let bit = compiler - .backend - .new_internal_var(Value::NthBit(val.clone(), nth), span); + let bit = compiler + .backend + .new_internal_var(Value::NthBit(val.clone(), nth), span); - Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) + Ok(Some(Var::new(vec![ConstOrCell::Cell(bit)], span))) + } } -// Ensure that the field size is not exceeded -fn check_field_size( - _compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - let var = &vars[0].var[0]; - let bit_len = B::Field::MODULUS_BIT_SIZE as u64; - - match var { - ConstOrCell::Const(cst) => { - let to_cmp = cst.to_u64(); - if to_cmp >= bit_len { - return Err(Error::new( - "constraint-generation", - ErrorKind::AssertionFailed, - span, - )); +impl Builtin for CheckFieldSizeFn { + const SIGNATURE: &'static str = "check_field_size(cmp: Field)"; + + fn builtin( + _compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + let var = &vars[0].var[0]; + let bit_len = B::Field::MODULUS_BIT_SIZE as u64; + + match var { + ConstOrCell::Const(cst) => { + let to_cmp = cst.to_u64(); + if to_cmp >= bit_len { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertionFailed, + span, + )); + } + Ok(None) } - Ok(None) + ConstOrCell::Cell(_) => Err(Error::new( + "constraint-generation", + ErrorKind::ExpectedConstant, + span, + )), } - ConstOrCell::Cell(_) => Err(Error::new( - "constraint-generation", - ErrorKind::ExpectedConstant, - span, - )), } } diff --git a/src/stdlib/builtins.rs b/src/stdlib/builtins.rs index c487237ac..701d0260d 100644 --- a/src/stdlib/builtins.rs +++ b/src/stdlib/builtins.rs @@ -23,9 +23,6 @@ use super::{FnInfoType, Module}; pub const QUALIFIED_BUILTINS: &str = "std/builtins"; pub const BUILTIN_FN_NAMES: [&str; 3] = ["assert", "assert_eq", "log"]; -const ASSERT_FN: &str = "assert(condition: Bool)"; -const ASSERT_EQ_FN: &str = "assert_eq(lhs: Field, rhs: Field)"; -const LOG_FN: &str = "log(var: Field)"; pub struct BuiltinsLib {} impl Module for BuiltinsLib { @@ -33,10 +30,10 @@ impl Module for BuiltinsLib { fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { vec![ - (ASSERT_FN, assert_fn, false), - (ASSERT_EQ_FN, assert_eq_fn, true), + (AssertFn::SIGNATURE, AssertFn::builtin, false), + (AssertEqFn::SIGNATURE, AssertEqFn::builtin, true), // true -> skip argument type checking for log - (LOG_FN, log_fn, true), + (LogFn::SIGNATURE, LogFn::builtin, true), ] } } @@ -63,7 +60,7 @@ fn assert_eq_values( match typ { // Field and Bool has the same logic - TyKind::Field { .. } | TyKind::Bool => { + TyKind::Field { .. } | TyKind::Bool | TyKind::String(..) => { let lhs_var = &lhs_info.var[0]; let rhs_var = &rhs_info.var[0]; match (lhs_var, rhs_var) { @@ -146,114 +143,143 @@ fn assert_eq_values( comparisons } -/// Asserts that two vars are equal. -fn assert_eq_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get two vars - assert_eq!(vars.len(), 2); - let lhs_info = &vars[0]; - let rhs_info = &vars[1]; - - // get types of both arguments - let lhs_type = lhs_info.typ.as_ref().ok_or_else(|| { - Error::new( - "constraint-generation", - ErrorKind::UnexpectedError("No type info for lhs of assertion"), - span, - ) - })?; - - let rhs_type = rhs_info.typ.as_ref().ok_or_else(|| { - Error::new( - "constraint-generation", - ErrorKind::UnexpectedError("No type info for rhs of assertion"), - span, - ) - })?; - - // they have the same type - if !lhs_type.match_expected(rhs_type, false) { - return Err(Error::new( - "constraint-generation", - ErrorKind::AssertEqTypeMismatch(lhs_type.clone(), rhs_type.clone()), - span, - )); - } +pub trait Builtin { + const SIGNATURE: &'static str; - // first collect all comparisons needed - let comparisons = assert_eq_values(compiler, lhs_info, rhs_info, lhs_type, span); + fn builtin( + compiler: &mut CircuitWriter, + generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>>; +} - // then add all the constraints - for comparison in comparisons { - match comparison { - Comparison::Vars(lhs, rhs) => { - compiler.backend.assert_eq_var(&lhs, &rhs, span); - } - Comparison::VarConst(var, constant) => { - compiler.backend.assert_eq_const(&var, constant, span); - } - Comparison::Constants(a, b) => { - if a != b { - return Err(Error::new( - "constraint-generation", - ErrorKind::AssertionFailed, - span, - )); +struct AssertEqFn {} +struct AssertFn {} +struct LogFn {} + +impl Builtin for AssertEqFn { + const SIGNATURE: &'static str = "assert_eq(lhs: Field, rhs: Field)"; + + /// Asserts that two vars are equal. + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // we get two vars + assert_eq!(vars.len(), 2); + let lhs_info = &vars[0]; + let rhs_info = &vars[1]; + + // get types of both arguments + let lhs_type = lhs_info.typ.as_ref().ok_or_else(|| { + Error::new( + "constraint-generation", + ErrorKind::UnexpectedError("No type info for lhs of assertion"), + span, + ) + })?; + + let rhs_type = rhs_info.typ.as_ref().ok_or_else(|| { + Error::new( + "constraint-generation", + ErrorKind::UnexpectedError("No type info for rhs of assertion"), + span, + ) + })?; + + // they have the same type + if !lhs_type.match_expected(rhs_type, false) { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertEqTypeMismatch(lhs_type.clone(), rhs_type.clone()), + span, + )); + } + + // first collect all comparisons needed + let comparisons = assert_eq_values(compiler, lhs_info, rhs_info, lhs_type, span); + + // then add all the constraints + for comparison in comparisons { + match comparison { + Comparison::Vars(lhs, rhs) => { + compiler.backend.assert_eq_var(&lhs, &rhs, span); + } + Comparison::VarConst(var, constant) => { + compiler.backend.assert_eq_const(&var, constant, span); + } + Comparison::Constants(a, b) => { + if a != b { + return Err(Error::new( + "constraint-generation", + ErrorKind::AssertionFailed, + span, + )); + } } } } - } - Ok(None) + Ok(None) + } } -/// Asserts that a condition is true. -fn assert_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get a single var - assert_eq!(vars.len(), 1); - - // of type bool - let var_info = &vars[0]; - assert!(matches!(var_info.typ, Some(TyKind::Bool))); - - // of only one field element - let var = &var_info.var; - assert_eq!(var.len(), 1); - let cond = &var[0]; - - match cond { - ConstOrCell::Const(cst) => { - assert!(cst.is_one()); - } - ConstOrCell::Cell(cvar) => { - let one = B::Field::one(); - compiler.backend.assert_eq_const(cvar, one, span); +impl Builtin for AssertFn { + const SIGNATURE: &'static str = "assert(condition: Bool)"; + + /// Asserts that a condition is true. + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result::Field, ::Var>>> { + // we get a single var + assert_eq!(vars.len(), 1); + + // of type bool + let var_info = &vars[0]; + assert!(matches!(var_info.typ, Some(TyKind::Bool))); + + // of only one field element + let var = &var_info.var; + assert_eq!(var.len(), 1); + let cond = &var[0]; + + match cond { + ConstOrCell::Const(cst) => { + assert!(cst.is_one()); + } + ConstOrCell::Cell(cvar) => { + let one = B::Field::one(); + compiler.backend.assert_eq_const(cvar, one, span); + } } - } - Ok(None) + Ok(None) + } } -/// Logging -fn log_fn( - compiler: &mut CircuitWriter, - generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - for var in vars { - // todo: will need to support string argument in order to customize msg - compiler.backend.log_var(var, span); - } +impl Builtin for LogFn { + // todo: currently only supports a single field var + // to support all the types, we can bypass the type check for this log function for now + const SIGNATURE: &'static str = "log(var: Field)"; + + /// Logging + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + for var in vars { + // todo: will need to support string argument in order to customize msg + compiler.backend.log_var(var, span); + } - Ok(None) + Ok(None) + } } diff --git a/src/stdlib/crypto.rs b/src/stdlib/crypto.rs index 66113cddd..13ff91a86 100644 --- a/src/stdlib/crypto.rs +++ b/src/stdlib/crypto.rs @@ -1,14 +1,27 @@ -use super::{FnInfoType, Module}; +use super::{builtins::Builtin, FnInfoType, Module}; use crate::backends::Backend; -const POSEIDON_FN: &str = "poseidon(input: [Field; 2]) -> [Field; 3]"; - pub struct CryptoLib {} impl Module for CryptoLib { const MODULE: &'static str = "crypto"; fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { - vec![(POSEIDON_FN, B::poseidon(), false)] + vec![(PoseidonFn::SIGNATURE, PoseidonFn::builtin, false)] + } +} + +struct PoseidonFn {} + +impl Builtin for PoseidonFn { + const SIGNATURE: &'static str = "poseidon(input: [Field; 2]) -> [Field; 3]"; + + fn builtin( + compiler: &mut crate::circuit_writer::CircuitWriter, + generics: &crate::parser::types::GenericParameters, + vars: &[crate::circuit_writer::VarInfo], + span: crate::constants::Span, + ) -> crate::error::Result>> { + B::poseidon()(compiler, generics, vars, span) } } diff --git a/src/stdlib/int.rs b/src/stdlib/int.rs index 03c574890..94f40ff89 100644 --- a/src/stdlib/int.rs +++ b/src/stdlib/int.rs @@ -11,9 +11,7 @@ use crate::{ var::{ConstOrCell, Value, Var}, }; -use super::{FnInfoType, Module}; - -const DIVMOD_FN: &str = "divmod(dividend: Field, divisor: Field) -> [Field; 2]"; +use super::{builtins::Builtin, FnInfoType, Module}; pub struct IntLib {} @@ -21,57 +19,63 @@ impl Module for IntLib { const MODULE: &'static str = "int"; fn get_fns() -> Vec<(&'static str, FnInfoType, bool)> { - vec![(DIVMOD_FN, divmod_fn, false)] + vec![(DivmodFn::SIGNATURE, DivmodFn::builtin, false)] } } /// Divides two field elements and returns the quotient and remainder. -fn divmod_fn( - compiler: &mut CircuitWriter, - _generics: &GenericParameters, - vars: &[VarInfo], - span: Span, -) -> Result>> { - // we get two vars - let dividend_info = &vars[0]; - let divisor_info = &vars[1]; - - // retrieve the values - let dividend_var = ÷nd_info.var[0]; - let divisor_var = &divisor_info.var[0]; - - match (dividend_var, divisor_var) { - // two constants - (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { - // convert to bigints - let a = a.to_biguint(); - let b = b.to_biguint(); - - let quotient = a.clone() / b.clone(); - let remainder = a % b; - - // convert back to fields - let quotient = B::Field::from_biguint("ient).unwrap(); - let remainder = B::Field::from_biguint(&remainder).unwrap(); - - Ok(Some(Var::new( - vec![ConstOrCell::Const(quotient), ConstOrCell::Const(remainder)], - span, - ))) - } +struct DivmodFn {} + +impl Builtin for DivmodFn { + const SIGNATURE: &'static str = "divmod(dividend: Field, divisor: Field) -> [Field; 2]"; + + fn builtin( + compiler: &mut CircuitWriter, + _generics: &GenericParameters, + vars: &[VarInfo], + span: Span, + ) -> Result>> { + // we get two vars + let dividend_info = &vars[0]; + let divisor_info = &vars[1]; + + // retrieve the values + let dividend_var = ÷nd_info.var[0]; + let divisor_var = &divisor_info.var[0]; + + match (dividend_var, divisor_var) { + // two constants + (ConstOrCell::Const(a), ConstOrCell::Const(b)) => { + // convert to bigints + let a = a.to_biguint(); + let b = b.to_biguint(); + + let quotient = a.clone() / b.clone(); + let remainder = a % b; + + // convert back to fields + let quotient = B::Field::from_biguint("ient).unwrap(); + let remainder = B::Field::from_biguint(&remainder).unwrap(); + + Ok(Some(Var::new( + vec![ConstOrCell::Const(quotient), ConstOrCell::Const(remainder)], + span, + ))) + } + + _ => { + let quotient = compiler + .backend + .new_internal_var(Value::Div(dividend_var.clone(), divisor_var.clone()), span); + let remainder = compiler + .backend + .new_internal_var(Value::Mod(dividend_var.clone(), divisor_var.clone()), span); - _ => { - let quotient = compiler - .backend - .new_internal_var(Value::Div(dividend_var.clone(), divisor_var.clone()), span); - let remainder = compiler - .backend - .new_internal_var(Value::Mod(dividend_var.clone(), divisor_var.clone()), span); - - Ok(Some(Var::new( - vec![ConstOrCell::Cell(quotient), ConstOrCell::Cell(remainder)], - span, - ))) + Ok(Some(Var::new( + vec![ConstOrCell::Cell(quotient), ConstOrCell::Cell(remainder)], + span, + ))) + } } } } From 1bcd765d78eae462516471c10a19d2746b90e48a Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma <114555115+0xnullifier@users.noreply.github.com> Date: Fri, 17 Jan 2025 18:26:15 +0530 Subject: [PATCH 2/4] Adds tuple type (#261) * support tuple type --- examples/log.no | 6 +- examples/tuple.no | 47 +++++++++ src/backends/mod.rs | 34 +++++- src/circuit_writer/ir.rs | 50 ++++++--- src/circuit_writer/writer.rs | 63 ++++++++--- src/error.rs | 7 +- src/inputs.rs | 16 +++ src/mast/mod.rs | 61 +++++++++-- src/name_resolution/context.rs | 5 + src/name_resolution/expr.rs | 9 +- src/parser/expr.rs | 62 +++++++++-- src/parser/types.rs | 83 ++++++++++++++- src/stdlib/builtins.rs | 26 +++++ src/type_checker/checker.rs | 62 ++++++++--- src/type_checker/mod.rs | 1 + src/utils/mod.rs | 186 ++++++++++++++++++++++----------- 16 files changed, 591 insertions(+), 127 deletions(-) create mode 100644 examples/tuple.no diff --git a/examples/log.no b/examples/log.no index a206495d3..c9072c268 100644 --- a/examples/log.no +++ b/examples/log.no @@ -13,7 +13,11 @@ fn main(pub public_input: Field) -> Field { log(arr); let thing = Thing { xx : public_input , yy: public_input + 1}; - log("formatted string with a number {} boolean {} arr {} struct {}" , 1234,true, arr , thing); + + log(thing); + + let tup = (1 , true , thing); + log("formatted string with a number {} boolean {} arr {} tuple {} struct {}" , 1234 , true, arr, tup, thing); return public_input + 1; } \ No newline at end of file diff --git a/examples/tuple.no b/examples/tuple.no new file mode 100644 index 000000000..867c91998 --- /dev/null +++ b/examples/tuple.no @@ -0,0 +1,47 @@ +struct Thing { + xx: Field, + tuple_field: (Field,Bool) +} + +// return tuples from functions +fn Thing.new(xx: Field , tup: (Field,Bool)) -> (Thing , (Field,Bool)) { + return ( + Thing { + xx: xx, + tuple_field:tup + }, + tup + ); +} + +fn generic_array_tuple_test(var : ([[Field;NN];LEN],Bool)) -> (Field , [Field;NN]) { + let zero = 0; + let result = if var[1] {var[0][LEN - 1][NN - 1]} else { var[0][LEN - 2][NN - 2] }; + return (result , var[0][LEN - 1]); +} + +// xx should be 0 +fn main(pub xx: [Field; 2]) -> Field { + // creation of new tuple with different types + let tup = (1, true); + + // create nested tuples + let nested_tup = ((false, [1,2,3]), 1); + log(nested_tup); // (1, (true , [1,2,3])) + + let incr = nested_tup[1]; // 1 + + // tuples can be input to function + let mut thing = Thing.new(xx[1] , (xx[0] , xx[0] == 0)); + + // you can access a tuple type just like you access a array + thing[0].tuple_field[0] += incr; + log(thing[0].tuple_field[0]); + let new_allocation = [xx,xx]; + let ret = generic_array_tuple_test((new_allocation, true)); + + assert_eq(thing[0].tuple_field[0] , 1); + log(ret[1]); // logs xx i.e [0,123] + + return ret[0]; +} \ No newline at end of file diff --git a/src/backends/mod.rs b/src/backends/mod.rs index a8ca9e849..0854c2b83 100644 --- a/src/backends/mod.rs +++ b/src/backends/mod.rs @@ -14,7 +14,7 @@ use crate::{ helpers::PrettyField, imports::FnHandle, parser::types::TyKind, - utils::{log_array_type, log_custom_type, log_string_type}, + utils::{log_array_or_tuple_type, log_custom_type, log_string_type}, var::{ConstOrCell, Value, Var}, witness::WitnessEnv, }; @@ -467,8 +467,20 @@ pub trait Backend: Clone { // Array Some(TyKind::Array(b, s)) => { - let (output, remaining) = - log_array_type(self, &var_info.var.cvars, b, *s, witness_env, typed, span)?; + let mut typs = Vec::with_capacity(*s as usize); + for _ in 0..(*s) { + typs.push((**b).clone()); + } + let (output, remaining) = log_array_or_tuple_type( + self, + &var_info.var.cvars, + &typs, + *s, + witness_env, + typed, + span, + false, + )?; assert!(remaining.is_empty()); println!("{dbg_msg}{}", output); } @@ -504,6 +516,22 @@ pub trait Backend: Clone { println!("{dbg_msg}{}", output); } + Some(TyKind::Tuple(typs)) => { + let len = typs.len(); + let (output, remaining) = log_array_or_tuple_type( + self, + &var_info.var.cvars, + &typs, + len as u32, + witness_env, + typed, + span, + true, + ) + .unwrap(); + assert!(remaining.is_empty()); + println!("{dbg_msg}{}", output); + } None => { return Err(Error::new( "log", diff --git a/src/circuit_writer/ir.rs b/src/circuit_writer/ir.rs index 4bdf45581..a8aed8fb8 100644 --- a/src/circuit_writer/ir.rs +++ b/src/circuit_writer/ir.rs @@ -972,11 +972,11 @@ impl IRWriter { Ok(Some(res)) } - ExprKind::ArrayAccess { array, idx } => { - // retrieve var of array + ExprKind::ArrayOrTupleAccess { container, idx } => { + // retrieve var of container let var = self - .compute_expr(fn_env, array)? - .expect("array access on non-array"); + .compute_expr(fn_env, container)? + .expect("container access on non-container"); // compute the index let idx_var = self @@ -987,10 +987,15 @@ impl IRWriter { .ok_or_else(|| self.error(ErrorKind::ExpectedConstant, expr.span))?; let idx: usize = idx.try_into().unwrap(); - // retrieve the type of the elements in the array - let array_typ = self.expr_type(array).expect("cannot find type of array"); + // retrieve the type of the elements in the container + let container_typ = self + .expr_type(container) + .expect("cannot find type of container"); - let elem_type = match array_typ { + // actual starting index for narrowing the var depends on the cotainer + // for arrays it is just idx * elem_size as all elements are of same size + // while for tuples we have to sum the sizes of all types up to that index + let (start, len) = match container_typ { TyKind::Array(ty, array_len) => { if idx >= (*array_len as usize) { return Err(self.error( @@ -998,18 +1003,25 @@ impl IRWriter { expr.span, )); } - ty + let len = self.size_of(ty); + let start = idx * self.size_of(ty); + (start, len) + } + + TyKind::Tuple(typs) => { + let mut starting_idx = 0; + for i in 0..idx { + starting_idx += self.size_of(&typs[i]); + } + (starting_idx, self.size_of(&typs[idx])) } _ => Err(Error::new( "compute-expr", - ErrorKind::UnexpectedError("expected array"), + ErrorKind::UnexpectedError("expected container"), expr.span, ))?, }; - // compute the size of each element in the array - let len = self.size_of(elem_type); - // compute the real index let start = idx * len; @@ -1074,6 +1086,20 @@ impl IRWriter { let var = VarOrRef::Var(Var::new(cvars, expr.span)); Ok(Some(var)) } + + ExprKind::TupleDeclaration(items) => { + let mut cvars = vec![]; + + for item in items { + let var = self.compute_expr(fn_env, item)?.unwrap(); + let to_extend = var.value(self, fn_env).cvars.clone(); + cvars.extend(to_extend); + } + + let var = VarOrRef::Var(Var::new(cvars, expr.span)); + + Ok(Some(var)) + } } } diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 8380707ee..5cbb1815b 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -332,6 +332,15 @@ impl CircuitWriter { unreachable!("generic array should have been resolved") } TyKind::String(_) => todo!("String type is not supported for constraints"), + TyKind::Tuple(types) => { + let mut offset = 0; + for ty in types { + let size = self.size_of(ty); + let slice = &input[offset..(offset + size)]; + self.constrain_inputs_to_main(slice, input_typ, span)?; + offset += size; + } + } }; Ok(()) } @@ -726,11 +735,11 @@ impl CircuitWriter { Ok(Some(res)) } - ExprKind::ArrayAccess { array, idx } => { - // retrieve var of array + ExprKind::ArrayOrTupleAccess { container, idx } => { + // retrieve var of container let var = self - .compute_expr(fn_env, array)? - .expect("array access on non-array"); + .compute_expr(fn_env, container)? + .expect("container access on non-container"); // compute the index let idx_var = self @@ -742,10 +751,15 @@ impl CircuitWriter { let idx: BigUint = idx.into(); let idx: usize = idx.try_into().unwrap(); - // retrieve the type of the elements in the array - let array_typ = self.expr_type(array).expect("cannot find type of array"); + // retrieve the type of the elements in the container + let container_typ = self + .expr_type(container) + .expect("cannot find type of container"); - let elem_type = match array_typ { + // actual starting index for narrowing the var depends on the cotainer + // for arrays it is just idx * elem_size as all elements are of same size + // while for tuples we have to sum the sizes of all types up to that index + let (start, len) = match container_typ { TyKind::Array(ty, array_len) => { if idx >= (*array_len as usize) { return Err(self.error( @@ -753,21 +767,25 @@ impl CircuitWriter { expr.span, )); } - ty + let len = self.size_of(ty); + let start = idx * self.size_of(ty); + (start, len) + } + + TyKind::Tuple(typs) => { + let mut start = 0; + for i in 0..idx { + start += self.size_of(&typs[i]); + } + (start, self.size_of(&typs[idx])) } _ => Err(Error::new( "compute-expr", - ErrorKind::UnexpectedError("expected array"), + ErrorKind::UnexpectedError("expected container"), expr.span, ))?, }; - // compute the size of each element in the array - let len = self.size_of(elem_type); - - // compute the real index - let start = idx * len; - // out-of-bound checks if start >= var.len() || start + len > var.len() { return Err(self.error( @@ -830,6 +848,21 @@ impl CircuitWriter { let var = VarOrRef::Var(Var::new(cvars, expr.span)); Ok(Some(var)) } + // exact copy of Array Declaration there is nothing really different at when looking it from a expression level + // as both of them are just `Vec` + ExprKind::TupleDeclaration(items) => { + let mut cvars = vec![]; + + for item in items { + let var = self.compute_expr(fn_env, item)?.unwrap(); + let to_extend = var.value(self, fn_env).cvars.clone(); + cvars.extend(to_extend); + } + + let var = VarOrRef::Var(Var::new(cvars, expr.span)); + + Ok(Some(var)) + } } } diff --git a/src/error.rs b/src/error.rs index e39b50636..9ac60c9a8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -268,6 +268,9 @@ pub enum ErrorKind { #[error("array accessed at index {0} is out of bounds (max allowed index is {1})")] ArrayIndexOutOfBounds(usize, usize), + #[error("tuple accessed at index {0} is out of bounds (max allowed index is {1})")] + TupleIndexOutofBounds(usize, usize), + #[error( "one-letter variables or types are not allowed. Best practice is to use descriptive names" )] @@ -327,8 +330,8 @@ pub enum ErrorKind { #[error("field access can only be applied on custom structs")] FieldAccessOnNonCustomStruct, - #[error("array access can only be performed on arrays")] - ArrayAccessOnNonArray, + #[error("array like access can only be performed on arrays or tuples")] + AccessOnNonCollection, #[error("struct `{0}` does not exist (are you sure it is defined?)")] UndefinedStruct(String), diff --git a/src/inputs.rs b/src/inputs.rs index d4dbaee5f..9994fa44f 100644 --- a/src/inputs.rs +++ b/src/inputs.rs @@ -151,6 +151,22 @@ impl CompiledCircuit { Ok(res) } + // parsing for tuple function inputs from json + (TyKind::Tuple(types), Value::Array(values)) => { + if values.len() != types.len() { + Err(ParsingError::ArraySizeMismatch( + values.len(), + types.len() as usize, + ))? + } + // making a vec with capacity allows for less number of reallocations + let mut res = Vec::with_capacity(values.len()); + for (ty, val) in types.iter().zip(values) { + let el = self.parse_single_input(val, ty)?; + res.extend(el); + } + Ok(res) + } (expected, observed) => { return Err(ParsingError::MismatchJsonArgument( expected.clone(), diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 31e40a3da..16830d83d 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -230,6 +230,10 @@ impl FnSig { val.to_u32().expect("array size exceeded u32"), ) } + TyKind::Tuple(typs) => { + let typs: Vec = typs.iter().map(|ty| self.resolve_type(ty, ctx)).collect(); + TyKind::Tuple(typs) + } _ => typ.clone(), } } @@ -251,6 +255,18 @@ impl FnSig { observed_arg.expr.span, )?; } + // if generics in tuple + (TyKind::Tuple(sig_arg_typs), TyKind::Tuple(observed_arg_typs)) => { + for (sig_arg_typ, observed_arg_typ) in + sig_arg_typs.iter().zip(observed_arg_typs) + { + self.resolve_generic_array( + &sig_arg_typ, + &observed_arg_typ, + observed_arg.expr.span, + )?; + } + } // const NN: Field _ => { let cst = observed_arg.constant.clone(); @@ -458,6 +474,7 @@ impl Mast { } TyKind::Bool => 1, TyKind::String(s) => s.len(), + TyKind::Tuple(typs) => typs.iter().map(|ty| self.size_of(ty)).sum(), } } } @@ -943,18 +960,26 @@ fn monomorphize_expr( res } - ExprKind::ArrayAccess { array, idx } => { + ExprKind::ArrayOrTupleAccess { container, idx } => { // get type of lhs - let array_mono = monomorphize_expr(ctx, array, mono_fn_env)?; + let array_mono = monomorphize_expr(ctx, container, mono_fn_env)?; let id_mono = monomorphize_expr(ctx, idx, mono_fn_env)?; // get type of element let el_typ = match array_mono.typ { Some(TyKind::Array(typkind, _)) => Some(*typkind), + Some(TyKind::Tuple(typs)) => match &idx.kind { + ExprKind::BigUInt(index) => Some(typs[index.to_usize().unwrap()].clone()), + _ => Err(Error::new( + "Non constant container access", + ErrorKind::ExpectedConstant, + expr.span, + ))?, + }, _ => Err(Error::new( - "Array Access", + "Container Access", ErrorKind::UnexpectedError( - "Attempting to access array when type is not an array", + "Attempting to access container when type is not an container", ), expr.span, ))?, @@ -962,8 +987,8 @@ fn monomorphize_expr( let mexpr = expr.to_mast( ctx, - &ExprKind::ArrayAccess { - array: Box::new(array_mono.expr), + &ExprKind::ArrayOrTupleAccess { + container: Box::new(array_mono.expr), idx: Box::new(id_mono.expr), }, ); @@ -1137,6 +1162,30 @@ fn monomorphize_expr( return Err(error(ErrorKind::InvalidArraySize, expr.span)); } } + + ExprKind::TupleDeclaration(items) => { + // checking the size of the tuple + let _: u32 = items.len().try_into().expect("tuple too large"); + + let items_mono: Vec = items + .iter() + .map(|item| monomorphize_expr(ctx, item, mono_fn_env).unwrap()) + .collect(); + + let typs: Vec = items_mono + .iter() + .cloned() + .map(|item_mono| item_mono.typ.unwrap()) + .collect(); + + let mexpr = expr.to_mast( + ctx, + &ExprKind::ArrayDeclaration(items_mono.into_iter().map(|e| e.expr).collect()), + ); + + let typ = TyKind::Tuple(typs); + ExprMonoInfo::new(mexpr, Some(typ), None) + } }; if let Some(typ) = &expr_mono.typ { diff --git a/src/name_resolution/context.rs b/src/name_resolution/context.rs index f06338772..23314adc1 100644 --- a/src/name_resolution/context.rs +++ b/src/name_resolution/context.rs @@ -134,6 +134,11 @@ impl NameResCtx { TyKind::GenericSizedArray(typ_kind, _) => self.resolve_typ_kind(typ_kind)?, TyKind::Bool => (), TyKind::String { .. } => (), + TyKind::Tuple(typs) => { + typs.iter_mut() + .for_each(|typ| self.resolve_typ_kind(typ).unwrap()); + () + } }; Ok(()) diff --git a/src/name_resolution/expr.rs b/src/name_resolution/expr.rs index 8fb34d438..b48b644b6 100644 --- a/src/name_resolution/expr.rs +++ b/src/name_resolution/expr.rs @@ -71,8 +71,8 @@ impl NameResCtx { ExprKind::Variable { module, name: _ } => { self.resolve(module, false)?; } - ExprKind::ArrayAccess { array, idx } => { - self.resolve_expr(array)?; + ExprKind::ArrayOrTupleAccess { container, idx } => { + self.resolve_expr(container)?; self.resolve_expr(idx)?; } ExprKind::ArrayDeclaration(items) => { @@ -105,6 +105,11 @@ impl NameResCtx { self.resolve_expr(then_)?; self.resolve_expr(else_)?; } + ExprKind::TupleDeclaration(items) => { + for expr in items { + self.resolve_expr(expr)?; + } + } }; Ok(()) diff --git a/src/parser/expr.rs b/src/parser/expr.rs index 8d1b5308b..4cb1ce62e 100644 --- a/src/parser/expr.rs +++ b/src/parser/expr.rs @@ -97,9 +97,14 @@ pub enum ExprKind { // TODO: change to `identifier` or `path`? Variable { module: ModulePath, name: Ident }, - /// An array access, for example: + /// An array or tuple access, for example: /// `lhs[idx]` - ArrayAccess { array: Box, idx: Box }, + /// As both almost work identical to each other expersion level we handle the cases for each container in the + /// circuit writers and typecheckers + ArrayOrTupleAccess { + container: Box, + idx: Box, + }, /// `[ ... ]` ArrayDeclaration(Vec), @@ -124,6 +129,9 @@ pub enum ExprKind { }, /// Any string literal StringLiteral(String), + + ///Tuple Declaration + TupleDeclaration(Vec), } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -208,6 +216,40 @@ impl Expr { // parenthesis TokenKind::LeftParen => { let mut expr = Expr::parse(ctx, tokens)?; + + // check if it is a tuple declaration + let second_token = tokens.peek(); + + match second_token { + // this means a tuple declaration + Some(Token { + kind: TokenKind::Comma, + span: _, + }) => { + let mut items = vec![expr]; + let last_span: Span; + loop { + let token = tokens.bump_err(ctx, ErrorKind::InvalidEndOfLine)?; + match token.kind { + TokenKind::RightParen => { + last_span = token.span; + break; + } + TokenKind::Comma => (), + _ => return Err(ctx.error(ErrorKind::InvalidEndOfLine, token.span)), + } + let item = Expr::parse(ctx, tokens)?; + items.push(item); + } + return Ok(Expr::new( + ctx, + ExprKind::TupleDeclaration(items), + span.merge_with(last_span), + )); + } + _ => (), + } + tokens.bump_expected(ctx, TokenKind::RightParen)?; if let ExprKind::BinaryOp { protected, .. } = &mut expr.kind { @@ -244,7 +286,7 @@ impl Expr { | ExprKind::Bool { .. } | ExprKind::BigUInt { .. } | ExprKind::FieldAccess { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } | ExprKind::StringLiteral { .. } ) { Err(Error::new( @@ -278,7 +320,7 @@ impl Expr { | ExprKind::Bool { .. } | ExprKind::BigUInt { .. } | ExprKind::FieldAccess { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } | ExprKind::StringLiteral { .. } ) { Err(Error::new( @@ -449,7 +491,7 @@ impl Expr { if !matches!( &self.kind, ExprKind::Variable { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } | ExprKind::FieldAccess { .. }, ) { return Err(ctx.error( @@ -596,7 +638,7 @@ impl Expr { parse_type_declaration(ctx, tokens, ident)? } - // array access + // array or tuple access Some(Token { kind: TokenKind::LeftBracket, .. @@ -608,7 +650,7 @@ impl Expr { self.kind, ExprKind::Variable { .. } | ExprKind::FieldAccess { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } ) { Err(Error::new( "parse_rhs - left bracket", @@ -628,8 +670,8 @@ impl Expr { Expr::new( ctx, - ExprKind::ArrayAccess { - array: Box::new(self), + ExprKind::ArrayOrTupleAccess { + container: Box::new(self), idx: Box::new(idx), }, span, @@ -682,7 +724,7 @@ impl Expr { &self.kind, ExprKind::FieldAccess { .. } | ExprKind::Variable { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } ) { let span = self.span.merge_with(period.span); return Err(ctx.error(ErrorKind::InvalidFieldAccessExpression, span)); diff --git a/src/parser/types.rs b/src/parser/types.rs index d0142e48b..ed1b73497 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -281,7 +281,9 @@ pub enum TyKind { /// A boolean (`true` or `false`). Bool, - // Tuple(Vec), + + /// A tuple is a data structure which contains many types + Tuple(Vec), // Bool, // U8, // U16, @@ -314,6 +316,7 @@ impl TyKind { /// - If `no_generic_allowed` is `true`, the function returns `false`. /// - If `no_generic_allowed` is `false`, the function compares the element types. /// - For `Custom` types, it compares the `module` and `name` fields for equality. + /// - For tuples, it matches type of each element i.e `self[i] == expected[i]` for every i /// - For other types, it uses a basic equality check. pub fn match_expected(&self, expected: &TyKind, no_generic_allowed: bool) -> bool { match (self, expected) { @@ -339,6 +342,19 @@ impl TyKind { ) => module == expected_module && name == expected_name, (TyKind::String { .. }, TyKind::String { .. }) => true, (x, y) if x == y => true, + (TyKind::Tuple(lhs), TyKind::Tuple(rhs)) => { + // if length does not match then they are of different type + if lhs.len() == rhs.len() { + let match_items = lhs + .iter() + .zip(rhs) + .filter(|&(l, r)| l.match_expected(r, no_generic_allowed)) + .count(); + lhs.len() == match_items + } else { + false + } + } _ => false, } } @@ -361,6 +377,12 @@ impl TyKind { generics.extend(sym.extract_generics()); } TyKind::String { .. } => (), + // for the time when (([Field;N])) + TyKind::Tuple(typs) => { + for ty in typs { + generics.extend(ty.extract_generics()); + } + } } generics @@ -401,6 +423,17 @@ impl Display for TyKind { TyKind::Bool => write!(f, "Bool"), TyKind::GenericSizedArray(ty, size) => write!(f, "[{}; {}]", ty, size), TyKind::String(s) => write!(f, "String({})", s), + TyKind::Tuple(types) => { + write!( + f, + "({})", + types + .iter() + .map(|ty| ty.to_string()) + .collect::>() + .join(",") + ) + } } } } @@ -505,6 +538,34 @@ impl Ty { } } + // tuple type return + TokenKind::LeftParen => { + let mut typs = vec![]; + loop { + // parse the type + let ty = Ty::parse(ctx, tokens)?; + typs.push(ty.kind); + + // if next token is RightParen then return the type + let token = tokens.peek(); + match token { + Some(token) => match token.kind { + TokenKind::RightParen => { + tokens.bump(ctx); + return Ok(Ty { + kind: TyKind::Tuple(typs), + span: token.span, + }); + } + // if there is a comma just bump the tokens so we are on the type + TokenKind::Comma => _ = tokens.bump(ctx), + _ => return Err(ctx.error(ErrorKind::InvalidEndOfLine, token.span)), + }, + _ => (), + } + } + } + // unrecognized _ => Err(ctx.error(ErrorKind::InvalidType, token.span)), } @@ -553,6 +614,15 @@ impl FnSig { generics.add(name); } } + // extracts generics from interior of tuple + TyKind::Tuple(typs) => { + for ty in typs { + let extracted = ty.extract_generics(); + for name in extracted { + generics.add(name); + } + } + } _ => (), } } @@ -608,6 +678,7 @@ impl FnSig { /// Either: /// - `const NN: Field` or `[[Field; NN]; MM]` /// - `[Field; cst]`, where cst is a constant variable. We also monomorphize generic array with a constant var as its size. + /// - `([Field; cst])` when tuple type returns a generic pub fn require_monomorphization(&self) -> bool { let has_arg_cst = self .arguments @@ -632,6 +703,7 @@ impl FnSig { self.has_constant(ty) } TyKind::Array(ty, _) => self.has_constant(ty), + TyKind::Tuple(typs) => typs.iter().any(|ty| self.has_constant(ty)), _ => false, } } @@ -906,6 +978,15 @@ impl FnArg { generics.insert(name); } } + // extract generics for inner type + TyKind::Tuple(typs) => { + for ty in typs { + let extracted = self.typ.kind.extract_generics(); + for name in extracted { + generics.insert(name); + } + } + } _ => (), } diff --git a/src/stdlib/builtins.rs b/src/stdlib/builtins.rs index 701d0260d..2c16e62fe 100644 --- a/src/stdlib/builtins.rs +++ b/src/stdlib/builtins.rs @@ -138,6 +138,32 @@ fn assert_eq_values( TyKind::GenericSizedArray(_, _) => { unreachable!("GenericSizedArray should be monomorphized") } + + TyKind::String(_) => todo!("String is not implemented yet"), + + TyKind::Tuple(typs) => { + let mut offset = 0; + for ty in typs { + let element_size = compiler.size_of(ty); + let mut element_comparisions = assert_eq_values( + compiler, + &VarInfo::new( + Var::new(lhs_info.var.range(offset, element_size).to_vec(), span), + false, + Some(ty.clone()), + ), + &VarInfo::new( + Var::new(rhs_info.var.range(offset, element_size).to_vec(), span), + false, + Some(ty.clone()), + ), + ty, + span, + ); + comparisons.append(&mut element_comparisions); + offset += element_size; + } + } } comparisons diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 28bd32141..cff75961a 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -281,16 +281,16 @@ impl TypeChecker { name.value.clone() } - // `array[idx] = ` - ExprKind::ArrayAccess { array, idx } => { - // get variable behind array - let array_node = self - .compute_type(array, typed_fn_env)? - .expect("type-checker bug: array access on an empty var"); - - array_node + // `array[idx] = ` or `tuple[idx] = rhs` + ExprKind::ArrayOrTupleAccess { container, idx } => { + // get variable behind container + let cotainer_node = self + .compute_type(container, typed_fn_env)? + .expect("type-checker bug: array or tuple access on an empty var"); + + cotainer_node .var_name - .expect("anonymous array access cannot be mutated") + .expect("anonymous array or tuple access cannot be mutated") } // `struct.field = ` @@ -464,13 +464,16 @@ impl TypeChecker { } } - ExprKind::ArrayAccess { array, idx } => { + ExprKind::ArrayOrTupleAccess { container, idx } => { // get type of lhs - let typ = self.compute_type(array, typed_fn_env)?.unwrap(); + let typ = self.compute_type(container, typed_fn_env)?.unwrap(); - // check that it is an array - if !matches!(typ.typ, TyKind::Array(..) | TyKind::GenericSizedArray(..)) { - Err(self.error(ErrorKind::ArrayAccessOnNonArray, expr.span))? + // check that it is an array or tuple + if !matches!( + typ.typ, + TyKind::Array(..) | TyKind::GenericSizedArray(..) | TyKind::Tuple(..) + ) { + Err(self.error(ErrorKind::AccessOnNonCollection, expr.span))? } // check that expression is a bigint @@ -484,6 +487,19 @@ impl TypeChecker { let el_typ = match typ.typ { TyKind::Array(typkind, _) => *typkind, TyKind::GenericSizedArray(typkind, _) => *typkind, + TyKind::Tuple(typs) => match &idx.kind { + ExprKind::BigUInt(index) => { + let idx = index.to_usize().unwrap(); + if idx >= typs.len() { + return Err(self.error( + ErrorKind::TupleIndexOutofBounds(idx, typs.len()), + expr.span, + )); + } + typs[idx].clone() + } + _ => return Err(self.error(ErrorKind::ExpectedConstant, expr.span)), + }, _ => Err(self.error(ErrorKind::UnexpectedError("not an array"), expr.span))?, }; @@ -518,6 +534,20 @@ impl TypeChecker { let res = ExprTyInfo::new_anon(TyKind::Array(Box::new(tykind), len)); Some(res) } + ExprKind::TupleDeclaration(items) => { + // restricting tuple len as array len + let _: u32 = items.len().try_into().expect("tuple too large"); + let typs: Vec = items + .iter() + .map(|item| { + self.compute_type(item, typed_fn_env) + .unwrap() + .expect("expected some val") + .typ + }) + .collect(); + Some(ExprTyInfo::new_anon(TyKind::Tuple(typs))) + } ExprKind::IfElse { cond, then_, else_ } => { // cond can only be a boolean @@ -536,7 +566,7 @@ impl TypeChecker { &then_.kind, ExprKind::Variable { .. } | ExprKind::FieldAccess { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } ) { return Err(self.error(ErrorKind::IfElseInvalidIfBranch(), then_.span)); } @@ -545,7 +575,7 @@ impl TypeChecker { &else_.kind, ExprKind::Variable { .. } | ExprKind::FieldAccess { .. } - | ExprKind::ArrayAccess { .. } + | ExprKind::ArrayOrTupleAccess { .. } ) { return Err(self.error(ErrorKind::IfElseInvalidElseBranch(), else_.span)); } diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 2b405c477..64497716b 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -488,6 +488,7 @@ impl TypeChecker { TyKind::Field { constant: false } | TyKind::Custom { .. } | TyKind::Array(_, _) + | TyKind::Tuple(_) | TyKind::Bool => { typed_fn_env.store_type( "public_output".to_string(), diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 7d3ef9b28..e59f4bae3 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -158,8 +158,21 @@ pub fn log_string_type( // Array Some(TyKind::Array(b, s)) => { - let (output, remaining) = - log_array_type(backend, &var.var.cvars, b, *s, witness, typed, span).unwrap(); + let mut typs = Vec::with_capacity(*s as usize); + for _ in 0..(*s) { + typs.push((**b).clone()); + } + let (output, remaining) = log_array_or_tuple_type( + backend, + &var.var.cvars, + &typs[..], + *s, + witness, + typed, + span, + false, + ) + .unwrap(); assert!(remaining.is_empty()); Ok(output) } @@ -190,6 +203,24 @@ pub fn log_string_type( unreachable!("GenericSizedArray should be monomorphized") } Some(TyKind::String(_)) => todo!("String cannot be in circuit yet"), + + Some(TyKind::Tuple(typs)) => { + println!("{:?}", typs); + let len = typs.len(); + let (output, remaining) = log_array_or_tuple_type( + backend, + &var.var.cvars, + &typs, + len as u32, + witness, + typed, + span, + true, + ) + .unwrap(); + assert!(remaining.is_empty()); + Ok(output) + } None => { return Err(Error::new( "log", @@ -203,76 +234,72 @@ pub fn log_string_type( replace_all(&re, str, replacer) } -pub fn log_array_type( +pub fn log_array_or_tuple_type( backend: &B, var_info_var: &[ConstOrCell], - base_type: &TyKind, + typs: &[TyKind], size: u32, witness: &mut WitnessEnv, typed: &Mast, span: &Span, + is_tuple: bool, ) -> Result<(String, Vec>)> { - match base_type { - TyKind::Field { .. } => { - let values: Vec = var_info_var - .iter() - .take(size as usize) - .map(|cvar| match cvar { - ConstOrCell::Const(cst) => cst.pretty(), - ConstOrCell::Cell(cell) => backend.compute_var(witness, cell).unwrap().pretty(), - }) - .collect(); - - let remaining = var_info_var[size as usize..].to_vec(); - Ok((format!("[{}]", values.join(", ")), remaining)) - } + let mut remaining = var_info_var.to_vec(); + let mut nested_result = Vec::new(); - TyKind::Bool => { - let values: Vec = var_info_var - .iter() - .take(size as usize) - .map(|cvar| match cvar { + for i in 0..size { + let base_type = &typs[i as usize]; + let (chunk_result, new_remaining) = match base_type { + TyKind::Field { .. } => { + let value = match &remaining[0] { + ConstOrCell::Const(cst) => cst.pretty(), + ConstOrCell::Cell(cell) => { + let val = backend.compute_var(witness, cell).unwrap(); + val.pretty() + } + }; + (value, remaining[1..].to_vec()) + } + // Bool + TyKind::Bool => { + let value = match &remaining[0] { ConstOrCell::Const(cst) => { let val = *cst == B::Field::one(); val.to_string() } ConstOrCell::Cell(cell) => { - let val = backend.compute_var(witness, cell).unwrap() == B::Field::one(); + let val = backend.compute_var(witness, cell)? == B::Field::one(); val.to_string() } - }) - .collect(); - - let remaining = var_info_var[size as usize..].to_vec(); - Ok((format!("[{}]", values.join(", ")), remaining)) - } - - TyKind::Array(inner_type, inner_size) => { - let mut nested_result = Vec::new(); - let mut remaining = var_info_var.to_vec(); - for _ in 0..size { - let (chunk_result, new_remaining) = log_array_type( + }; + (value, remaining[1..].to_vec()) + } + TyKind::Array(inner_type, inner_size) => { + let mut vec_inner_type = Vec::with_capacity(remaining.len()); + for _ in 0..remaining.len() { + vec_inner_type.push((**inner_type).clone()); + } + let is_tuple = match **inner_type { + TyKind::Tuple(_) => true, + _ => false, + }; + log_array_or_tuple_type( backend, &remaining, - inner_type, + &vec_inner_type[..], *inner_size, witness, typed, span, - )?; - nested_result.push(chunk_result); - remaining = new_remaining; + is_tuple, + )? } - Ok((format!("[{}]", nested_result.join(", ")), remaining)) - } - TyKind::Custom { - module, - name: struct_name, - } => { - let mut nested_result = Vec::new(); - let mut remaining = var_info_var.to_vec(); - for _ in 0..size { + // Custom types + TyKind::Custom { + module, + name: struct_name, + } => { let mut string_vec = Vec::new(); let (output, new_remaining) = log_custom_type( backend, @@ -284,17 +311,37 @@ pub fn log_array_type( span, &mut string_vec, )?; - nested_result.push(format!("{}{}", struct_name, output)); - remaining = new_remaining; + (format!("{}{}", struct_name, output), new_remaining) } - Ok((format!("[{}]", nested_result.join(", ")), remaining)) - } - TyKind::GenericSizedArray(_, _) => { - unreachable!("GenericSizedArray should be monomorphized") - } + // GenericSizedArray + TyKind::GenericSizedArray(_, _) => { + unreachable!("GenericSizedArray should be monomorphized") + } + TyKind::String(_) => todo!("String cannot be in circuit yet"), + + TyKind::Tuple(inner_typs) => { + let inner_size = inner_typs.len(); + log_array_or_tuple_type( + backend, + &remaining, + &inner_typs, + inner_size as u32, + witness, + typed, + span, + true, + )? + } + }; + nested_result.push(chunk_result); + remaining = new_remaining; + } - TyKind::String(_) => todo!("String type cannot be used in circuits!"), + if is_tuple { + Ok((format!("({})", nested_result.join(",")), remaining)) + } else { + Ok((format!("[{}]", nested_result.join(",")), remaining)) } } pub fn log_custom_type( @@ -348,8 +395,20 @@ pub fn log_custom_type( }, TyKind::Array(b, s) => { - let (output, new_remaining) = - log_array_type(backend, &remaining, b, *s, witness, typed, span)?; + let len = remaining.len(); + let mut typs: Vec = Vec::with_capacity(len); + typs.push((**b).clone()); + + let (output, new_remaining) = log_array_or_tuple_type( + backend, + &remaining, + &typs[..], + *s, + witness, + typed, + span, + false, + )?; string_vec.push(format!("{field_name}: {}", output)); remaining = new_remaining; } @@ -379,6 +438,15 @@ pub fn log_custom_type( TyKind::String(s) => { todo!("String cannot be a type for customs it is only for logging") } + TyKind::Tuple(typs) => { + let len = typs.len(); + let (output, new_remaining) = log_array_or_tuple_type( + backend, &remaining, &typs, len as u32, witness, typed, span, true, + ) + .unwrap(); + string_vec.push(format!("{field_name}: {}", output)); + remaining = new_remaining; + } } } From 1f4844c7d066a0faa3eff93a936b39bcca5d547b Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Tue, 21 Jan 2025 09:10:18 +0700 Subject: [PATCH 3/4] support subtraction operation in array symbolic value (#251) * support subtraction in symbolic value --- src/mast/mod.rs | 1 + src/parser/types.rs | 5 ++++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 16830d83d..0d9421780 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -414,6 +414,7 @@ impl Symbolic { } Symbolic::Generic(g) => gens.get(&g.value), Symbolic::Add(a, b) => a.eval(gens, tast) + b.eval(gens, tast), + Symbolic::Sub(a, b) => a.eval(gens, tast) - b.eval(gens, tast), Symbolic::Mul(a, b) => a.eval(gens, tast) * b.eval(gens, tast), } } diff --git a/src/parser/types.rs b/src/parser/types.rs index ed1b73497..0e7262d60 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -188,6 +188,7 @@ pub enum Symbolic { /// Generic parameter Generic(Ident), Add(Box, Box), + Sub(Box, Box), Mul(Box, Box), } @@ -198,6 +199,7 @@ impl Display for Symbolic { Symbolic::Constant(ident) => write!(f, "{}", ident.value), Symbolic::Generic(ident) => write!(f, "{}", ident.value), Symbolic::Add(lhs, rhs) => write!(f, "{} + {}", lhs, rhs), + Symbolic::Sub(lhs, rhs) => write!(f, "{} - {}", lhs, rhs), Symbolic::Mul(lhs, rhs) => write!(f, "{} * {}", lhs, rhs), } } @@ -219,7 +221,7 @@ impl Symbolic { Symbolic::Generic(ident) => { generics.insert(ident.value.clone()); } - Symbolic::Add(lhs, rhs) | Symbolic::Mul(lhs, rhs) => { + Symbolic::Add(lhs, rhs) | Symbolic::Mul(lhs, rhs) | Symbolic::Sub(lhs, rhs) => { generics.extend(lhs.extract_generics()); generics.extend(rhs.extract_generics()); } @@ -251,6 +253,7 @@ impl Symbolic { // no protected flags are needed, as this is based on expression nodes which already ordered the operations match op { Op2::Addition => Ok(Symbolic::Add(Box::new(lhs), Box::new(rhs?))), + Op2::Subtraction => Ok(Symbolic::Sub(Box::new(lhs), Box::new(rhs?))), Op2::Multiplication => Ok(Symbolic::Mul(Box::new(lhs), Box::new(rhs?))), _ => Err(Error::new( "mast", From a9a83b8d8cfa14f7391a3ed840908a09e14e931e Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Tue, 21 Jan 2025 17:45:57 +0700 Subject: [PATCH 4/4] =?UTF-8?q?fix:=20instead=20of=20folding=20expr=20for?= =?UTF-8?q?=20constant=20values,=20we=20keep=20the=20expr=20as=E2=80=A6=20?= =?UTF-8?q?(#252)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix: instead of folding expr for constant values, we keep the expr as it is * fix: avoid re-monomorphized a function name --- examples/functions.no | 2 ++ src/error.rs | 3 +++ src/mast/mod.rs | 40 +++++++++++++++++++++++++++++++--------- src/parser/types.rs | 6 ++++++ 4 files changed, 42 insertions(+), 9 deletions(-) diff --git a/examples/functions.no b/examples/functions.no index d69bf196a..2a66c2cc0 100644 --- a/examples/functions.no +++ b/examples/functions.no @@ -10,6 +10,8 @@ fn main(pub one: Field) { let four = add(one, 3); assert_eq(four, 4); + // double() should not be folded to return 8 + // the asm test will catch the missing constraint if it is folded let eight = double(4); assert_eq(eight, double(four)); } diff --git a/src/error.rs b/src/error.rs index 9ac60c9a8..9d086fc2e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -372,6 +372,9 @@ pub enum ErrorKind { #[error("division by zero")] DivisionByZero, + #[error("lhs `{0}` is less than rhs `{1}`")] + NegativeLhsLessThanRhs(String, String), + #[error("Not enough variables provided to fill placeholders in the formatted string")] InsufficientVariables, } diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 0d9421780..e2c6e0dcf 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -827,20 +827,42 @@ fn monomorphize_expr( let ExprMonoInfo { expr: rhs_expr, .. } = rhs_mono; // fold constants - let cst = match (&lhs_expr.kind, &rhs_expr.kind) { - (ExprKind::BigUInt(lhs), ExprKind::BigUInt(rhs)) => match op { - Op2::Addition => Some(lhs + rhs), - Op2::Subtraction => Some(lhs - rhs), - Op2::Multiplication => Some(lhs * rhs), - Op2::Division => Some(lhs / rhs), - _ => None, - }, + let cst = match (&lhs_mono.constant, &rhs_mono.constant) { + (Some(PropagatedConstant::Single(lhs)), Some(PropagatedConstant::Single(rhs))) => { + match op { + Op2::Addition => Some(lhs + rhs), + Op2::Subtraction => { + if lhs < rhs { + // throw error + return Err(error( + ErrorKind::NegativeLhsLessThanRhs( + lhs.to_string(), + rhs.to_string(), + ), + expr.span, + )); + } + Some(lhs - rhs) + } + Op2::Multiplication => Some(lhs * rhs), + Op2::Division => Some(lhs / rhs), + _ => None, + } + } _ => None, }; match cst { Some(v) => { - let mexpr = expr.to_mast(ctx, &ExprKind::BigUInt(v.clone())); + let mexpr = expr.to_mast( + ctx, + &ExprKind::BinaryOp { + op: op.clone(), + protected: *protected, + lhs: Box::new(lhs_expr), + rhs: Box::new(rhs_expr), + }, + ); ExprMonoInfo::new(mexpr, typ, Some(PropagatedConstant::from(v))) } diff --git a/src/parser/types.rs b/src/parser/types.rs index 0e7262d60..9a3ee98ef 100644 --- a/src/parser/types.rs +++ b/src/parser/types.rs @@ -716,6 +716,12 @@ impl FnSig { pub fn monomorphized_name(&self) -> Ident { let mut name = self.name.clone(); + // check if it contains # in the name + if name.value.contains('#') { + // if so, then it is already monomorphized + return name; + } + if self.require_monomorphization() { let mut generics = self.generics.parameters.iter().collect::>(); generics.sort_by(|a, b| a.0.cmp(b.0));