From 4596f2b5d77e08d34724f589890396634903f960 Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Thu, 23 Jan 2025 06:53:06 +0700 Subject: [PATCH 1/3] fix: invalid folded mutable constant value (#250) --- src/negative_tests.rs | 23 +++++++++++++++++++++++ src/type_checker/checker.rs | 4 ++++ src/type_checker/fn_env.rs | 22 ++++++++++++++++++++++ 3 files changed, 49 insertions(+) diff --git a/src/negative_tests.rs b/src/negative_tests.rs index dd557cf71..438015292 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -472,6 +472,29 @@ fn test_generic_custom_type_mismatched() { )); } +#[test] +fn test_generic_mutated_cst_var_in_loop() { + let code = r#" + fn gen(const LEN: Field) -> [Field; LEN] { + return [0; LEN]; + } + + fn main(pub xx: Field) { + let mut loopvar = 1; + for ii in 0..3 { + loopvar = loopvar + 1; + } + let arr = gen(loopvar); + } + "#; + + let res = tast_pass(code).0; + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::ArgumentTypeMismatch(..) + )); +} + #[test] fn test_array_bounds() { let code = r#" diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index cff75961a..0c7ef41c7 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -261,6 +261,10 @@ impl TypeChecker { .compute_type(lhs, typed_fn_env)? .expect("type-checker bug: lhs access on an empty var"); + if let Some(var_name) = &lhs_node.var_name { + typed_fn_env.invalidate_cst_var(var_name); + } + // todo: check and update the const field type for other cases // lhs can be a local variable or a path to an array let lhs_name = match &lhs.kind { diff --git a/src/type_checker/fn_env.rs b/src/type_checker/fn_env.rs index a8540cfd6..1bbf25113 100644 --- a/src/type_checker/fn_env.rs +++ b/src/type_checker/fn_env.rs @@ -110,6 +110,28 @@ impl TypedFnEnv { self.current_scope >= prefix_scope } + /// Because we don't support forloop unrolling, + /// we should invalidate the constant value behind a mutable variable, which is used in a forloop. + /// ```ignore + /// let mut pow2 = 1; + /// for ii in 0..LEN { + /// pow2 = pow2 + pow2; + /// } + /// ``` + /// Instead of folding the constant value to the mutable variable in this case, + /// the actual value will be calculated during synthesizer phase. + pub fn invalidate_cst_var(&mut self, ident: &str) { + // only applies to the variables in the parent scopes + // remove the constant value + if let Some((scope, info)) = self.vars.get_mut(ident) { + if scope < &mut self.current_scope + && matches!(info.typ, TyKind::Field { constant: true }) + { + info.typ = TyKind::Field { constant: false }; + } + } + } + /// Since currently we don't support unrolling, the generic function calls are assumed to target a same instance. /// Each loop iteration should instantiate generic function calls with the same parameters. /// This assumption requires a few type checking rules to forbid the cases that needs unrolling. From bb87a88f88c1bc3e933e50d5ea96053e549528c8 Mon Sep 17 00:00:00 2001 From: Kata Choi Date: Thu, 23 Jan 2025 06:54:00 +0700 Subject: [PATCH 2/3] support pub attribute (#263) * support pub attribute for struct fields --- examples/assignment.no | 2 +- examples/hint.no | 4 ++-- examples/types.no | 4 ++-- examples/types_array.no | 4 ++-- examples/types_array_output.no | 4 ++-- src/circuit_writer/ir.rs | 2 +- src/circuit_writer/writer.rs | 4 ++-- src/error.rs | 3 +++ src/inputs.rs | 2 +- src/mast/mod.rs | 6 +++--- src/name_resolution/context.rs | 2 +- src/negative_tests.rs | 25 ++++++++++++++++++++-- src/parser/mod.rs | 2 ++ src/parser/structs.rs | 38 ++++++++++++++++++++++++++++++---- src/stdlib/builtins.rs | 2 +- src/stdlib/native/int/lib.no | 20 +++++++++++++++++- src/tests/modules.rs | 2 +- src/tests/stdlib/uints/mod.rs | 2 +- src/type_checker/checker.rs | 38 +++++++++++++++++++++++++++------- src/type_checker/fn_env.rs | 23 +++++++++++++++----- src/type_checker/mod.rs | 6 +++--- src/utils/mod.rs | 2 +- 22 files changed, 153 insertions(+), 44 deletions(-) diff --git a/examples/assignment.no b/examples/assignment.no index b82be55cc..6d966a131 100644 --- a/examples/assignment.no +++ b/examples/assignment.no @@ -1,5 +1,5 @@ struct Thing { - xx: Field, + pub xx: Field, } fn try_to_mutate(thing: Thing) { diff --git a/examples/hint.no b/examples/hint.no index e8ac1bc2d..bea559acf 100644 --- a/examples/hint.no +++ b/examples/hint.no @@ -1,6 +1,6 @@ struct Thing { - xx: Field, - yy: Field, + pub xx: Field, + pub yy: Field, } hint fn mul(lhs: Field, rhs: Field) -> Field { diff --git a/examples/types.no b/examples/types.no index 7aa792d6d..d219d3f3c 100644 --- a/examples/types.no +++ b/examples/types.no @@ -1,6 +1,6 @@ struct Thing { - xx: Field, - yy: Field, + pub xx: Field, + pub yy: Field, } fn main(pub xx: Field, pub yy: Field) { diff --git a/examples/types_array.no b/examples/types_array.no index ef08c51bf..e50830164 100644 --- a/examples/types_array.no +++ b/examples/types_array.no @@ -1,6 +1,6 @@ struct Thing { - xx: Field, - yy: Field, + pub xx: Field, + pub yy: Field, } fn main(pub xx: Field, pub yy: Field) { diff --git a/examples/types_array_output.no b/examples/types_array_output.no index 43c75493e..1693be0c7 100644 --- a/examples/types_array_output.no +++ b/examples/types_array_output.no @@ -1,6 +1,6 @@ struct Thing { - xx: Field, - yy: Field, + pub xx: Field, + pub yy: Field, } fn main(pub xx: Field, pub yy: Field) -> [Thing; 2] { diff --git a/src/circuit_writer/ir.rs b/src/circuit_writer/ir.rs index a8aed8fb8..bceb39ca3 100644 --- a/src/circuit_writer/ir.rs +++ b/src/circuit_writer/ir.rs @@ -706,7 +706,7 @@ impl IRWriter { // find range of field let mut start = 0; let mut len = 0; - for (field, field_typ) in &struct_info.fields { + for (field, field_typ, _attribute) in &struct_info.fields { if field == &rhs.value { len = self.size_of(field_typ); break; diff --git a/src/circuit_writer/writer.rs b/src/circuit_writer/writer.rs index 5cbb1815b..f5b8e2c61 100644 --- a/src/circuit_writer/writer.rs +++ b/src/circuit_writer/writer.rs @@ -320,7 +320,7 @@ impl CircuitWriter { .clone(); let mut offset = 0; - for (_field_name, field_typ) in &struct_info.fields { + for (_field_name, field_typ, _attribute) in &struct_info.fields { let len = self.size_of(field_typ); let range = offset..(offset + len); self.constrain_inputs_to_main(&input[range], field_typ, span)?; @@ -501,7 +501,7 @@ impl CircuitWriter { // find range of field let mut start = 0; let mut len = 0; - for (field, field_typ) in &struct_info.fields { + for (field, field_typ, _attribute) in &struct_info.fields { if field == &rhs.value { len = self.size_of(field_typ); break; diff --git a/src/error.rs b/src/error.rs index 9d086fc2e..d70af53e0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -372,6 +372,9 @@ pub enum ErrorKind { #[error("division by zero")] DivisionByZero, + #[error("cannot access private field `{1}` of struct `{0}` from outside its methods.")] + PrivateFieldAccess(String, String), + #[error("lhs `{0}` is less than rhs `{1}`")] NegativeLhsLessThanRhs(String, String), diff --git a/src/inputs.rs b/src/inputs.rs index 9994fa44f..ca81eea55 100644 --- a/src/inputs.rs +++ b/src/inputs.rs @@ -141,7 +141,7 @@ impl CompiledCircuit { // parse each field let mut res = vec![]; - for (field_name, field_ty) in fields { + for (field_name, field_ty, _attribute) in fields { let value = map.remove(field_name).ok_or_else(|| { ParsingError::MissingStructFieldIdent(field_name.to_string()) })?; diff --git a/src/mast/mod.rs b/src/mast/mod.rs index e2c6e0dcf..9b92eac98 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -463,7 +463,7 @@ impl Mast { let mut sum = 0; - for (_, t) in &struct_info.fields { + for (_, t, _) in &struct_info.fields { sum += self.size_of(t); } @@ -567,8 +567,8 @@ fn monomorphize_expr( let typ = struct_info .fields .iter() - .find(|(name, _)| name == &rhs.value) - .map(|(_, typ)| typ.clone()); + .find(|(name, _, _)| name == &rhs.value) + .map(|(_, typ, _)| typ.clone()); let mexpr = expr.to_mast( ctx, diff --git a/src/name_resolution/context.rs b/src/name_resolution/context.rs index 23314adc1..cd0733346 100644 --- a/src/name_resolution/context.rs +++ b/src/name_resolution/context.rs @@ -156,7 +156,7 @@ impl NameResCtx { self.resolve(module, true)?; // we resolve the fully-qualified types of the fields - for (_field_name, field_typ) in fields { + for (_field_name, field_typ, _attribute) in fields { self.resolve_typ_kind(&mut field_typ.kind)?; } diff --git a/src/negative_tests.rs b/src/negative_tests.rs index 438015292..50fb49257 100644 --- a/src/negative_tests.rs +++ b/src/negative_tests.rs @@ -723,7 +723,7 @@ fn test_nonhint_call_with_unsafe() { fn test_no_cst_struct_field_prop() { let code = r#" struct Thing { - val: Field, + pub val: Field, } fn gen(const LEN: Field) -> [Field; LEN] { @@ -748,7 +748,7 @@ fn test_no_cst_struct_field_prop() { fn test_mut_cst_struct_field_prop() { let code = r#" struct Thing { - val: Field, + pub val: Field, } fn gen(const LEN: Field) -> [Field; LEN] { @@ -770,3 +770,24 @@ fn test_mut_cst_struct_field_prop() { ErrorKind::ArgumentTypeMismatch(..) )); } + +#[test] +fn test_private_field_access() { + let code = r#" + struct Room { + pub beds: Field, // public + size: Field // private + } + + fn main(pub beds: Field) { + let room = Room {beds: beds, size: 10}; + room.size = 5; // not allowed + } + "#; + + let res = tast_pass(code).0; + assert!(matches!( + res.unwrap_err().kind, + ErrorKind::PrivateFieldAccess(..) + )); +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index e69ee0ee3..27cbf6dda 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -247,6 +247,8 @@ mod tests { let parsed = StructDef::parse(ctx, tokens); assert!(parsed.is_err()); assert!(parsed.as_ref().err().is_some()); + + println!("{:?}", parsed); match &parsed.as_ref().err().unwrap().kind { ErrorKind::ExpectedTokenNotKeyword(keyword, _) => { assert_eq!(keyword, "pub"); diff --git a/src/parser/structs.rs b/src/parser/structs.rs index 6bfed47a1..2a31bbbdf 100644 --- a/src/parser/structs.rs +++ b/src/parser/structs.rs @@ -3,12 +3,12 @@ use serde::{Deserialize, Serialize}; use crate::{ constants::Span, error::{ErrorKind, Result}, - lexer::{Token, TokenKind, Tokens}, + lexer::{Keyword, Token, TokenKind, Tokens}, syntax::is_type, }; use super::{ - types::{Ident, ModulePath, Ty, TyKind}, + types::{Attribute, AttributeKind, Ident, ModulePath, Ty, TyKind}, Error, ParserCtx, }; @@ -17,7 +17,7 @@ pub struct StructDef { //pub attribute: Attribute, pub module: ModulePath, // name resolution pub name: CustomType, - pub fields: Vec<(Ident, Ty)>, + pub fields: Vec<(Ident, Ty, Option)>, pub span: Span, } @@ -55,6 +55,36 @@ impl StructDef { tokens.bump(ctx); break; } + + // check for pub keyword + // struct Foo { pub a: Field, b: Field } + // ^ + let attribute = if matches!( + tokens.peek(), + Some(Token { + kind: TokenKind::Keyword(Keyword::Pub), + .. + }) + ) { + let token = tokens.bump(ctx).unwrap(); + // next token shouldn't be : + if tokens.peek().unwrap().kind == TokenKind::Colon { + return Err(ctx.error( + ErrorKind::ExpectedTokenNotKeyword( + "pub".to_string(), + TokenKind::Identifier("".to_string()), + ), + token.span, + )); + } + Some(Attribute { + kind: AttributeKind::Pub, + span: token.span, + }) + } else { + None + }; + // struct Foo { a: Field, b: Field } // ^ let field_name = Ident::parse(ctx, tokens)?; @@ -67,7 +97,7 @@ impl StructDef { // ^^^^^ let field_ty = Ty::parse(ctx, tokens)?; span = span.merge_with(field_ty.span); - fields.push((field_name, field_ty)); + fields.push((field_name, field_ty, attribute)); // struct Foo { a: Field, b: Field } // ^ ^ diff --git a/src/stdlib/builtins.rs b/src/stdlib/builtins.rs index 2c16e62fe..82b738214 100644 --- a/src/stdlib/builtins.rs +++ b/src/stdlib/builtins.rs @@ -111,7 +111,7 @@ fn assert_eq_values( // compare each field recursively let mut offset = 0; - for (_, field_type) in &struct_info.fields { + for (_, field_type, _) in &struct_info.fields { let field_size = compiler.size_of(field_type); let mut field_comparisons = assert_eq_values( compiler, diff --git a/src/stdlib/native/int/lib.no b/src/stdlib/native/int/lib.no index 315c252cd..648a80e77 100644 --- a/src/stdlib/native/int/lib.no +++ b/src/stdlib/native/int/lib.no @@ -291,4 +291,22 @@ fn Uint32.mod(self, rhs: Uint32) -> Uint32 { fn Uint64.mod(self, rhs: Uint64) -> Uint64 { let res = self.divmod(rhs); return res[1]; -} \ No newline at end of file +} + +// implement to field +fn Uint8.to_field(self) -> Field { + return self.inner; +} + +fn Uint16.to_field(self) -> Field { + return self.inner; +} + +fn Uint32.to_field(self) -> Field { + return self.inner; +} + +fn Uint64.to_field(self) -> Field { + return self.inner; +} + \ No newline at end of file diff --git a/src/tests/modules.rs b/src/tests/modules.rs index 6517cd4b8..e4b18b35c 100644 --- a/src/tests/modules.rs +++ b/src/tests/modules.rs @@ -31,7 +31,7 @@ use mimoo::liblib; // test a library's type that links to its own type struct Inner { - inner: Field, + pub inner: Field, } struct Lib { diff --git a/src/tests/stdlib/uints/mod.rs b/src/tests/stdlib/uints/mod.rs index efd1ede83..474c017bf 100644 --- a/src/tests/stdlib/uints/mod.rs +++ b/src/tests/stdlib/uints/mod.rs @@ -14,7 +14,7 @@ fn main(pub lhs: Field, rhs: Field) -> Field { let res = lhs_u.{opr}(rhs_u); - return res.inner; + return res.to_field(); } "#; diff --git a/src/type_checker/checker.rs b/src/type_checker/checker.rs index 0c7ef41c7..45f543e31 100644 --- a/src/type_checker/checker.rs +++ b/src/type_checker/checker.rs @@ -11,8 +11,8 @@ use crate::{ imports::FnKind, parser::{ types::{ - is_numeric, FnSig, ForLoopArgument, FunctionDef, ModulePath, Stmt, StmtKind, Symbolic, - Ty, TyKind, + is_numeric, Attribute, AttributeKind, FnSig, ForLoopArgument, FuncOrMethod, + FunctionDef, ModulePath, Stmt, StmtKind, Symbolic, Ty, TyKind, }, CustomType, Expr, ExprKind, Op2, }, @@ -58,7 +58,7 @@ impl FnInfo { #[derive(Deserialize, Serialize, Default, Debug, Clone)] pub struct StructInfo { pub name: String, - pub fields: Vec<(String, TyKind)>, + pub fields: Vec<(String, TyKind, Option)>, pub methods: HashMap, } @@ -119,14 +119,36 @@ impl TypeChecker { .expect("this struct is not defined, or you're trying to access a field of a struct defined in a third-party library (TODO: better error)"); // find field type - let res = struct_info + if let Some((_, field_typ, attribute)) = struct_info .fields .iter() - .find(|(name, _)| name == &rhs.value) - .map(|(_, typ)| typ.clone()); + .find(|(field_name, _, _)| field_name == &rhs.value) + { + // check for the pub attribute + let is_public = matches!( + attribute, + &Some(Attribute { + kind: AttributeKind::Pub, + .. + }) + ); + + // check if we're inside a method of the same struct + let in_method = matches!( + typed_fn_env.current_fn_kind(), + FuncOrMethod::Method(method_struct) if method_struct.name == struct_name + ); - if let Some(res) = res { - Some(ExprTyInfo::new(lhs_node.var_name, res)) + if is_public || in_method { + // allow access + Some(ExprTyInfo::new(lhs_node.var_name, field_typ.clone())) + } else { + // block access + Err(self.error( + ErrorKind::PrivateFieldAccess(struct_name.clone(), rhs.value.clone()), + expr.span, + ))? + } } else { return Err(self.error( ErrorKind::UndefinedField(struct_info.name.clone(), rhs.value.clone()), diff --git a/src/type_checker/fn_env.rs b/src/type_checker/fn_env.rs index 1bbf25113..0e0643318 100644 --- a/src/type_checker/fn_env.rs +++ b/src/type_checker/fn_env.rs @@ -5,7 +5,7 @@ use std::collections::HashMap; use crate::{ constants::Span, error::{Error, ErrorKind, Result}, - parser::types::TyKind, + parser::types::{FuncOrMethod, TyKind}, }; /// Some type information on local variables that we want to track in the [TypedFnEnv] environment. @@ -39,7 +39,7 @@ impl TypeInfo { } /// The environment we use to type check functions. -#[derive(Default, Debug, Clone)] +#[derive(Debug, Clone)] pub struct TypedFnEnv { /// The current nesting level. /// Starting at 0 (top level), and increasing as we go into a block. @@ -55,12 +55,21 @@ pub struct TypedFnEnv { /// Determines if forloop variables are allowed to be accessed. forbid_forloop_scope: bool, + + /// The kind of function we're currently type checking + current_fn_kind: FuncOrMethod, } impl TypedFnEnv { - /// Creates a new TypeEnv - pub fn new() -> Self { - Self::default() + /// Creates a new TypeEnv with the given function kind + pub fn new(fn_kind: &FuncOrMethod) -> Self { + Self { + current_scope: 0, + vars: HashMap::new(), + forloop_scopes: Vec::new(), + forbid_forloop_scope: false, + current_fn_kind: fn_kind.clone(), + } } /// Enters a scoped block. @@ -204,4 +213,8 @@ impl TypedFnEnv { Ok(None) } } + + pub fn current_fn_kind(&self) -> &FuncOrMethod { + &self.current_fn_kind + } } diff --git a/src/type_checker/mod.rs b/src/type_checker/mod.rs index 64497716b..d5ea9f0ad 100644 --- a/src/type_checker/mod.rs +++ b/src/type_checker/mod.rs @@ -298,8 +298,8 @@ impl TypeChecker { let fields: Vec<_> = fields .iter() .map(|field| { - let (name, typ) = field; - (name.value.clone(), typ.kind.clone()) + let (name, typ, attribute) = field; + (name.value.clone(), typ.kind.clone(), attribute.clone()) }) .collect(); @@ -329,7 +329,7 @@ impl TypeChecker { // `fn main() { ... }` RootKind::FunctionDef(function) => { // create a new typed fn environment to type check the function - let mut typed_fn_env = TypedFnEnv::default(); + let mut typed_fn_env = TypedFnEnv::new(&function.sig.kind); // if we're expecting a library, this should not be the main function let is_main = function.is_main(); diff --git a/src/utils/mod.rs b/src/utils/mod.rs index e59f4bae3..40cfde60e 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -366,7 +366,7 @@ pub fn log_custom_type( let mut remaining = var_info_var.to_vec(); - for (field_name, field_typ) in &struct_info.fields { + for (field_name, field_typ, _) in &struct_info.fields { let len = typed.size_of(field_typ); match field_typ { TyKind::Field { .. } => match &remaining[0] { From cd26e9a39bbfde51914a0e6ab48c0c8b41ef8e79 Mon Sep 17 00:00:00 2001 From: Utkarsh Sharma <114555115+0xnullifier@users.noreply.github.com> Date: Thu, 23 Jan 2025 05:25:57 +0530 Subject: [PATCH 3/3] bug-fix: correct type matching in if-else expression (#264) --- src/mast/mod.rs | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/mast/mod.rs b/src/mast/mod.rs index 9b92eac98..4e7120840 100644 --- a/src/mast/mod.rs +++ b/src/mast/mod.rs @@ -1063,7 +1063,17 @@ fn monomorphize_expr( let else_mono = monomorphize_expr(ctx, else_, mono_fn_env)?; // make sure that the type of then_ and else_ match - if then_mono.typ != else_mono.typ { + let is_match = match (&then_mono.typ, &else_mono.typ) { + // generics not allowed as they should have been monomorphized + (Some(then_typ), Some(else_typ)) => then_typ.match_expected(else_typ, true), + _ => Err(Error::new( + "If-Else Monomorphization", + ErrorKind::UnexpectedError("Could not resolve type for the `if-else` branch"), + expr.span, + ))?, + }; + + if !is_match { Err(Error::new( "If-Else Monomorphization", ErrorKind::UnexpectedError(