diff --git a/src/frontend/mod.rs b/src/frontend/mod.rs index a4a49e2a..f564598c 100644 --- a/src/frontend/mod.rs +++ b/src/frontend/mod.rs @@ -1,7 +1,11 @@ //! The frontend includes the user-level abstractions and user-friendly types to define and work //! with Pods. -use std::{collections::HashMap, convert::From, fmt}; +use std::{ + collections::{HashMap, HashSet}, + convert::From, + fmt, +}; use itertools::Itertools; use schemars::JsonSchema; @@ -187,11 +191,11 @@ impl MainPodBuilder { } pub fn pub_op(&mut self, op: Operation) -> Result { - self.op(true, op) + self.op(true, vec![], op) } pub fn priv_op(&mut self, op: Operation) -> Result { - self.op(false, op) + self.op(false, vec![], op) } /// Lower syntactic sugar operation into backend compatible operation. @@ -370,7 +374,17 @@ impl MainPodBuilder { } } - fn op_statement(&mut self, op: Operation) -> Result { + fn op_statement( + &mut self, + wildcard_values: Vec<(usize, Value)>, + op: Operation, + ) -> Result { + // Check for duplicate wildcard value assignments + let mut uniq = HashSet::new(); + if !wildcard_values.iter().all(|(index, _)| uniq.insert(*index)) { + return Err(Error::custom("duplicate wildcard value assignments")); + } + use NativeOperation::*; let st = match op.0 { OperationType::Native(o) => { @@ -553,6 +567,16 @@ impl MainPodBuilder { let mut wildcard_map = vec![Option::None; self.params.max_custom_predicate_wildcards]; + for (index, value) in wildcard_values.into_iter() { + if index >= wildcard_map.len() { + return Err(Error::custom(format!( + "wildcard index {} greater-equal than max {}", + index, + wildcard_map.len() - 1, + ))); + } + wildcard_map[index] = Some(value); + } for (st_tmpl, st) in pred.statements.iter().zip(args.iter()) { let st_args = st.args(); for (st_tmpl_arg, st_arg) in st_tmpl.args.iter().zip(&st_args) { @@ -601,10 +625,16 @@ impl MainPodBuilder { Ok(()) } - fn op(&mut self, public: bool, op: Operation) -> Result { + /// `wildcard_values`: wildcard values to use instead of EMPTY_VALUE for unresolved wildcards + pub fn op( + &mut self, + public: bool, + wildcard_values: Vec<(usize, Value)>, + op: Operation, + ) -> Result { self.add_entries_contains(&op)?; let op = Self::fill_in_aux(Self::lower_op(op)?)?; - let st = self.op_statement(op.clone())?; + let st = self.op_statement(wildcard_values, op.clone())?; self.insert(public, (st, op))?; Ok(self.statements[self.statements.len() - 1].clone()) @@ -772,6 +802,7 @@ pub mod tests { tickets_pod_full_flow, zu_kyc_pod_builder, zu_kyc_pod_request, zu_kyc_sign_dict_builders, EthDosHelper, MOCK_VD_SET, }, + lang::parse, middleware::{ containers::{Array, Set}, Signer as _, Value, @@ -930,7 +961,7 @@ pub mod tests { ], OperationAux::None, ); - let st1 = builder.op(true, op_eq1).unwrap(); + let st1 = builder.op(true, vec![], op_eq1).unwrap(); let op_eq2 = Operation( OperationType::Native(NativeOperation::EqualFromEntries), vec![ @@ -939,14 +970,14 @@ pub mod tests { ], OperationAux::None, ); - let st2 = builder.op(true, op_eq2).unwrap(); + let st2 = builder.op(true, vec![], op_eq2).unwrap(); let op_eq3 = Operation( OperationType::Native(NativeOperation::TransitiveEqualFromStatements), vec![OperationArg::Statement(st1), OperationArg::Statement(st2)], OperationAux::None, ); - builder.op(true, op_eq3).unwrap(); + builder.op(true, vec![], op_eq3).unwrap(); let prover = MockProver {}; let pod = builder.prove(&prover).unwrap(); @@ -1006,7 +1037,7 @@ pub mod tests { let st0 = signed_dict.get_statement("dict").unwrap(); let local = dict!(32, {"key" => "a"})?; let st1 = builder - .op(true, Operation::dict_contains(local, "key", "a")) + .op(true, vec![], Operation::dict_contains(local, "key", "a")) .unwrap(); builder.pub_op(Operation( @@ -1228,6 +1259,7 @@ pub mod tests { let st1 = builder .op( true, + vec![], Operation::dict_contains(local, "known_secret", SecretKey(BigUint::from(123u32))), ) .unwrap(); @@ -1327,4 +1359,40 @@ pub mod tests { let pod = builder.prove(&prover).unwrap(); pod.pod.verify().unwrap(); } + + #[test] + fn test_wildcard_values() -> Result<()> { + let params = Params::default(); + let vd_set = &*MOCK_VD_SET; + + let input = r#" + Test(a, b) = OR( + Equal(a, 5) + Equal(b, 5) + ) + "#; + let batch = parse(input, ¶ms, &[]).unwrap().custom_batch; + let pred_test = batch.predicate_ref_by_name("Test").unwrap(); + + // Try to build with wrong type in 1st arg + let mut builder = MainPodBuilder::new(¶ms, vd_set); + let st0 = builder.priv_op(Operation::eq(5, 5)).unwrap(); + let wildcard_values = vec![(1, Value::from(42))]; + let st = builder + .op( + true, + wildcard_values, + Operation::custom(pred_test, [st0, Statement::None]), + ) + .unwrap(); + let st_args = st.args(); + assert_eq!(StatementArg::Literal(Value::from(5)), st_args[0]); + assert_eq!(StatementArg::Literal(Value::from(42)), st_args[1]); + + let prover = MockProver {}; + let pod = builder.prove(&prover).unwrap(); + pod.pod.verify().unwrap(); + + Ok(()) + } }