diff --git a/air-script/src/cli/transpile.rs b/air-script/src/cli/transpile.rs index cf09e3bad..8e5f9fdd7 100644 --- a/air-script/src/cli/transpile.rs +++ b/air-script/src/cli/transpile.rs @@ -77,6 +77,7 @@ impl Transpile { .chain(mir::passes::AstToMir::new(&diagnostics)) .chain(mir::passes::Inlining::new(&diagnostics)) .chain(mir::passes::Unrolling::new(&diagnostics)) + .chain(mir::passes::BusOpExpand::new(&diagnostics)) .chain(air_ir::passes::MirToAir::new(&diagnostics)); pipeline.run(ast) }) diff --git a/air-script/tests/buses/buses_complex.air b/air-script/tests/buses/buses_complex.air new file mode 100644 index 000000000..0b2eaaa90 --- /dev/null +++ b/air-script/tests/buses/buses_complex.air @@ -0,0 +1,39 @@ +def BusesAir + +trace_columns { + main: [a, b, s1, s2, d], +} + +buses { + unit p, + mult q, +} + +public_inputs { + inputs: [2], +} + +boundary_constraints { + enf p.first = null; + enf q.first = null; + enf p.last = null; + enf q.last = null; + # TODO: to be used when we have support for variable-length public inputs + #enf p.last = inputs; + #enf q.last = inputs; +} + +integrity_constraints { + enf s1^2 = s1; + enf s2^2 = s2; + + p.add(1, a) when s1; + p.rem(1, b) when s2; + + p.add(2, b) when 1 - s1; + p.rem(2, a) when 1 - s2; + + q.add(3, a) when s1; + q.add(3, a) when s1; + q.rem(3, a) for d; +} diff --git a/air-script/tests/buses/buses_complex.masm b/air-script/tests/buses/buses_complex.masm new file mode 100644 index 000000000..0485635e5 --- /dev/null +++ b/air-script/tests/buses/buses_complex.masm @@ -0,0 +1,179 @@ +# Procedure to efficiently compute the required exponentiations of the out-of-domain point `z` and cache them for later use. +# +# This computes the power of `z` needed to evaluate the periodic polynomials and the constraint divisors +# +# Input: [...] +# Output: [...] +proc.cache_z_exp + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, ...] + # Exponentiate z trace_len times + mem_load.4294903307 neg + # => [count, z_1, z_0, ...] where count = -log2(trace_len) + dup.0 neq.0 + while.true + movdn.2 dup.1 dup.1 ext2mul + # => [(e_1, e_0)^n, i, ...] + movup.2 add.1 dup.0 neq.0 + # => [b, i+1, (e_1, e_0)^n, ...] + end # END while + push.0 mem_storew.500000100 # z^trace_len + # => [0, 0, (z_1, z_0)^trace_len, ...] + dropw # Clean stack +end # END PROC cache_z_exp + +# Procedure to compute the exemption points. +# +# Input: [...] +# Output: [g^{-2}, g^{-1}, ...] +proc.get_exemptions_points + mem_load.4294799999 + # => [g, ...] + push.1 swap div + # => [g^{-1}, ...] + dup.0 dup.0 mul + # => [g^{-2}, g^{-1}, ...] +end # END PROC get_exemptions_points + +# Procedure to compute the integrity constraint divisor. +# +# The divisor is defined as `(z^trace_len - 1) / ((z - g^{trace_len-2}) * (z - g^{trace_len-1}))` +# Procedure `cache_z_exp` must have been called prior to this. +# +# Input: [...] +# Output: [divisor_1, divisor_0, ...] +proc.compute_integrity_constraint_divisor + padw mem_loadw.500000100 drop drop # load z^trace_len + # Comments below use zt = `z^trace_len` + # => [zt_1, zt_0, ...] + push.1 push.0 ext2sub + # => [zt_1-1, zt_0-1, ...] + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, zt_1-1, zt_0-1, ...] + exec.get_exemptions_points + # => [g^{trace_len-2}, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + dup.0 mem_store.500000101 # Save a copy of `g^{trace_len-2} to be used by the boundary divisor + dup.3 dup.3 movup.3 push.0 ext2sub + # => [e_1, e_0, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + movup.4 movup.4 movup.4 push.0 ext2sub + # => [e_3, e_2, e_1, e_0, zt_1-1, zt_0-1, ...] + ext2mul + # => [denominator_1, denominator_0, zt_1-1, zt_0-1, ...] + ext2div + # => [divisor_1, divisor_0, ...] +end # END PROC compute_integrity_constraint_divisor + +# Procedure to evaluate numerators of all integrity constraints. +# +# All the 2 main and 2 auxiliary constraints are evaluated. +# The result of each evaluation is kept on the stack, with the top of the stack +# containing the evaluations for the auxiliary trace (if any) followed by the main trace. +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# where: (r_1, r_0) is the quadratic extension element resulting from the integrity constraint evaluation. +# This procedure pushes 4 quadratic extension field elements to the stack +proc.compute_integrity_constraints + # integrity constraint 0 for main + padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 movdn.3 movdn.3 drop drop ext2mul + # integrity constraint 1 for main + padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2mul padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 drop drop ext2mul + # integrity constraint 0 for aux + push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2add ext2add padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.2 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2mul push.1 push.0 push.1 push.0 padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2sub ext2sub ext2add ext2add padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop ext2mul push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900001 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2mul push.1 push.0 padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2sub ext2add ext2add padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.2 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add push.1 push.0 padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2sub ext2mul push.1 push.0 push.1 push.0 padw mem_loadw.4294900003 movdn.3 movdn.3 drop drop ext2sub ext2sub ext2add ext2add padw mem_loadw.4294900072 drop drop ext2mul ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900201 movdn.3 movdn.3 drop drop ext2mul + # integrity constraint 1 for aux + push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900073 movdn.3 movdn.3 drop drop ext2mul push.0 push.0 push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2add push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900002 movdn.3 movdn.3 drop drop ext2mul ext2add ext2add push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900073 drop drop ext2mul push.0 push.0 push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.3 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add padw mem_loadw.4294900000 movdn.3 movdn.3 drop drop padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900004 movdn.3 movdn.3 drop drop ext2mul ext2add ext2add ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900201 drop drop ext2mul +end # END PROC compute_integrity_constraints + +# Procedure to evaluate the boundary constraint numerator for the first row of the auxiliary trace +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# Where: (r_1, r_0) is one quadratic extension field element for each constraint +proc.compute_boundary_constraints_aux_first + # boundary constraint 0 for aux + padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900202 movdn.3 movdn.3 drop drop ext2mul + # boundary constraint 1 for aux + padw mem_loadw.4294900073 movdn.3 movdn.3 drop drop push.0 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900202 drop drop ext2mul +end # END PROC compute_boundary_constraints_aux_first + +# Procedure to evaluate the boundary constraint numerator for the last row of the auxiliary trace +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# Where: (r_1, r_0) is one quadratic extension field element for each constraint +proc.compute_boundary_constraints_aux_last + # boundary constraint 2 for aux + padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900203 movdn.3 movdn.3 drop drop ext2mul + # boundary constraint 3 for aux + padw mem_loadw.4294900073 movdn.3 movdn.3 drop drop push.0 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900203 drop drop ext2mul +end # END PROC compute_boundary_constraints_aux_last + +# Procedure to evaluate all integrity constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_integrity_constraints + exec.compute_integrity_constraints + # Numerator of the transition constraint polynomial + ext2add ext2add ext2add ext2add + # Divisor of the transition constraint polynomial + exec.compute_integrity_constraint_divisor + ext2div # divide the numerator by the divisor +end # END PROC evaluate_integrity_constraints + +# Procedure to evaluate all boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_boundary_constraints + exec.compute_boundary_constraints_aux_last + # Accumulate the numerator for segment 1 LastRow + ext2add ext2add + # => [(aux_last1, aux_last0), ...] + # Compute the denominator for domain LastRow + padw mem_loadw.4294903304 drop drop # load z + mem_load.500000101 push.0 ext2sub + # Compute numerator/denominator for last row + ext2div + exec.compute_boundary_constraints_aux_first + # Accumulate the numerator for segment 1 FirstRow + ext2add ext2add + # => [(aux_first1, aux_first0), ...] + # Compute the denominator for domain FirstRow + padw mem_loadw.4294903304 drop drop # load z + push.1 push.0 ext2sub + # Compute numerator/denominator for first row + ext2div + # Add first and last row groups + ext2add +end # END PROC evaluate_boundary_constraints + +# Procedure to evaluate the integrity and boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +export.evaluate_constraints + exec.cache_z_exp + exec.evaluate_integrity_constraints + exec.evaluate_boundary_constraints + ext2add +end # END PROC evaluate_constraints + diff --git a/air-script/tests/buses/buses_complex.rs b/air-script/tests/buses/buses_complex.rs new file mode 100644 index 000000000..91236ffb0 --- /dev/null +++ b/air-script/tests/buses/buses_complex.rs @@ -0,0 +1,96 @@ +use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement}; +use winter_utils::collections::Vec; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + inputs: [Felt; 2], +} + +impl PublicInputs { + pub fn new(inputs: [Felt; 2]) -> Self { + Self { inputs } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: &mut W) { + target.write(self.inputs.as_slice()); + } +} + +pub struct BusesAir { + context: AirContext, + inputs: [Felt; 2], +} + +impl BusesAir { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for BusesAir { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![TransitionConstraintDegree::new(2), TransitionConstraintDegree::new(2)]; + let aux_degrees = vec![TransitionConstraintDegree::new(3), TransitionConstraintDegree::new(4)]; + let num_main_assertions = 0; + let num_aux_assertions = 4; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, inputs: public_inputs.inputs } + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxTraceRandElements) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(0, 0, E::ONE)); + result.push(Assertion::single(1, 0, E::ZERO)); + result.push(Assertion::single(0, self.last_step(), E::ONE)); + result.push(Assertion::single(1, self.last_step(), E::ZERO)); + result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + result[0] = main_current[2] * main_current[2] - main_current[2]; + result[1] = main_current[3] * main_current[3] - main_current[3]; + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + result[0] = (E::ONE + (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * E::from(main_current[2]) + E::ONE - E::from(main_current[2]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[1]) * aux_rand_elements.get_segment_elements(0)[2]) * (E::ONE - E::from(main_current[2])) + E::ONE - (E::ONE - E::from(main_current[2]))) * aux_current[0] - (E::ONE + (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[1]) * aux_rand_elements.get_segment_elements(0)[2]) * E::from(main_current[3]) + E::ONE - E::from(main_current[3]) + (aux_rand_elements.get_segment_elements(0)[0] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * (E::ONE - E::from(main_current[3])) + E::ONE - (E::ONE - E::from(main_current[3]))) * aux_next[0]; + result[1] = E::ONE * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * aux_current[1] + E::ZERO + E::ONE * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * E::from(main_current[2]) + E::ONE * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * E::from(main_current[2]) - (E::ONE * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * aux_next[1] + E::ZERO + E::ONE * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::from(3_u64) * aux_rand_elements.get_segment_elements(0)[1] + E::from(main_current[0]) * aux_rand_elements.get_segment_elements(0)[2]) * E::from(main_current[4])); + } +} \ No newline at end of file diff --git a/air-script/tests/buses/buses_simple.air b/air-script/tests/buses/buses_simple.air new file mode 100644 index 000000000..49aff1c19 --- /dev/null +++ b/air-script/tests/buses/buses_simple.air @@ -0,0 +1,32 @@ +def BusesAir + +trace_columns { + main: [a], +} + +buses { + unit p, + mult q, +} + +public_inputs { + inputs: [2], +} + +boundary_constraints { + enf p.first = null; + enf q.first = null; + enf p.last = null; + enf q.last = null; + # TODO: to be used when we have support for variable-length public inputs + #enf p.last = inputs; + #enf q.last = inputs; +} + +integrity_constraints { + p.add(1) when 1; + p.rem(1) when 1; + q.add(1, 2) when 1; + q.add(1, 2) when 1; + q.rem(1, 2) for 2; +} diff --git a/air-script/tests/buses/buses_simple.masm b/air-script/tests/buses/buses_simple.masm new file mode 100644 index 000000000..44bf95e9a --- /dev/null +++ b/air-script/tests/buses/buses_simple.masm @@ -0,0 +1,171 @@ +# Procedure to efficiently compute the required exponentiations of the out-of-domain point `z` and cache them for later use. +# +# This computes the power of `z` needed to evaluate the periodic polynomials and the constraint divisors +# +# Input: [...] +# Output: [...] +proc.cache_z_exp + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, ...] + # Exponentiate z trace_len times + mem_load.4294903307 neg + # => [count, z_1, z_0, ...] where count = -log2(trace_len) + dup.0 neq.0 + while.true + movdn.2 dup.1 dup.1 ext2mul + # => [(e_1, e_0)^n, i, ...] + movup.2 add.1 dup.0 neq.0 + # => [b, i+1, (e_1, e_0)^n, ...] + end # END while + push.0 mem_storew.500000100 # z^trace_len + # => [0, 0, (z_1, z_0)^trace_len, ...] + dropw # Clean stack +end # END PROC cache_z_exp + +# Procedure to compute the exemption points. +# +# Input: [...] +# Output: [g^{-2}, g^{-1}, ...] +proc.get_exemptions_points + mem_load.4294799999 + # => [g, ...] + push.1 swap div + # => [g^{-1}, ...] + dup.0 dup.0 mul + # => [g^{-2}, g^{-1}, ...] +end # END PROC get_exemptions_points + +# Procedure to compute the integrity constraint divisor. +# +# The divisor is defined as `(z^trace_len - 1) / ((z - g^{trace_len-2}) * (z - g^{trace_len-1}))` +# Procedure `cache_z_exp` must have been called prior to this. +# +# Input: [...] +# Output: [divisor_1, divisor_0, ...] +proc.compute_integrity_constraint_divisor + padw mem_loadw.500000100 drop drop # load z^trace_len + # Comments below use zt = `z^trace_len` + # => [zt_1, zt_0, ...] + push.1 push.0 ext2sub + # => [zt_1-1, zt_0-1, ...] + padw mem_loadw.4294903304 drop drop # load z + # => [z_1, z_0, zt_1-1, zt_0-1, ...] + exec.get_exemptions_points + # => [g^{trace_len-2}, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + dup.0 mem_store.500000101 # Save a copy of `g^{trace_len-2} to be used by the boundary divisor + dup.3 dup.3 movup.3 push.0 ext2sub + # => [e_1, e_0, g^{trace_len-1}, z_1, z_0, zt_1-1, zt_0-1, ...] + movup.4 movup.4 movup.4 push.0 ext2sub + # => [e_3, e_2, e_1, e_0, zt_1-1, zt_0-1, ...] + ext2mul + # => [denominator_1, denominator_0, zt_1-1, zt_0-1, ...] + ext2div + # => [divisor_1, divisor_0, ...] +end # END PROC compute_integrity_constraint_divisor + +# Procedure to evaluate numerators of all integrity constraints. +# +# All the 0 main and 2 auxiliary constraints are evaluated. +# The result of each evaluation is kept on the stack, with the top of the stack +# containing the evaluations for the auxiliary trace (if any) followed by the main trace. +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# where: (r_1, r_0) is the quadratic extension element resulting from the integrity constraint evaluation. +# This procedure pushes 2 quadratic extension field elements to the stack +proc.compute_integrity_constraints + # integrity constraint 0 for aux + push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.1 push.0 ext2mul push.1 push.0 push.1 push.0 ext2sub ext2add ext2add padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop ext2mul push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.1 push.0 ext2mul push.1 push.0 push.1 push.0 ext2sub ext2add ext2add padw mem_loadw.4294900072 drop drop ext2mul ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 movdn.3 movdn.3 drop drop ext2mul + # integrity constraint 1 for aux + push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900073 movdn.3 movdn.3 drop drop ext2mul push.0 push.0 push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul push.1 push.0 ext2mul ext2add push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul push.1 push.0 ext2mul ext2add ext2add push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900073 drop drop ext2mul push.0 push.0 push.1 push.0 padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul padw mem_loadw.4294900150 movdn.3 movdn.3 drop drop push.1 push.0 padw mem_loadw.4294900150 drop drop ext2mul ext2add push.2 push.0 padw mem_loadw.4294900151 movdn.3 movdn.3 drop drop ext2mul ext2add ext2mul push.2 push.0 ext2mul ext2add ext2add ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900200 drop drop ext2mul +end # END PROC compute_integrity_constraints + +# Procedure to evaluate the boundary constraint numerator for the first row of the auxiliary trace +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# Where: (r_1, r_0) is one quadratic extension field element for each constraint +proc.compute_boundary_constraints_aux_first + # boundary constraint 0 for aux + padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900201 movdn.3 movdn.3 drop drop ext2mul + # boundary constraint 1 for aux + padw mem_loadw.4294900073 movdn.3 movdn.3 drop drop push.0 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900201 drop drop ext2mul +end # END PROC compute_boundary_constraints_aux_first + +# Procedure to evaluate the boundary constraint numerator for the last row of the auxiliary trace +# +# Input: [...] +# Output: [(r_1, r_0)*, ...] +# Where: (r_1, r_0) is one quadratic extension field element for each constraint +proc.compute_boundary_constraints_aux_last + # boundary constraint 2 for aux + padw mem_loadw.4294900072 movdn.3 movdn.3 drop drop push.1 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900202 movdn.3 movdn.3 drop drop ext2mul + # boundary constraint 3 for aux + padw mem_loadw.4294900073 movdn.3 movdn.3 drop drop push.0 push.0 ext2sub + # Multiply by the composition coefficient + padw mem_loadw.4294900202 drop drop ext2mul +end # END PROC compute_boundary_constraints_aux_last + +# Procedure to evaluate all integrity constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_integrity_constraints + exec.compute_integrity_constraints + # Numerator of the transition constraint polynomial + ext2add ext2add + # Divisor of the transition constraint polynomial + exec.compute_integrity_constraint_divisor + ext2div # divide the numerator by the divisor +end # END PROC evaluate_integrity_constraints + +# Procedure to evaluate all boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +# Where: (r_1, r_0) is the final result with the divisor applied +proc.evaluate_boundary_constraints + exec.compute_boundary_constraints_aux_last + # Accumulate the numerator for segment 1 LastRow + ext2add ext2add + # => [(aux_last1, aux_last0), ...] + # Compute the denominator for domain LastRow + padw mem_loadw.4294903304 drop drop # load z + mem_load.500000101 push.0 ext2sub + # Compute numerator/denominator for last row + ext2div + exec.compute_boundary_constraints_aux_first + # Accumulate the numerator for segment 1 FirstRow + ext2add ext2add + # => [(aux_first1, aux_first0), ...] + # Compute the denominator for domain FirstRow + padw mem_loadw.4294903304 drop drop # load z + push.1 push.0 ext2sub + # Compute numerator/denominator for first row + ext2div + # Add first and last row groups + ext2add +end # END PROC evaluate_boundary_constraints + +# Procedure to evaluate the integrity and boundary constraints. +# +# Input: [...] +# Output: [(r_1, r_0), ...] +export.evaluate_constraints + exec.cache_z_exp + exec.evaluate_integrity_constraints + exec.evaluate_boundary_constraints + ext2add +end # END PROC evaluate_constraints + diff --git a/air-script/tests/buses/buses_simple.rs b/air-script/tests/buses/buses_simple.rs new file mode 100644 index 000000000..6148948a0 --- /dev/null +++ b/air-script/tests/buses/buses_simple.rs @@ -0,0 +1,94 @@ +use winter_air::{Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions as WinterProofOptions, TransitionConstraintDegree, TraceInfo}; +use winter_math::fields::f64::BaseElement as Felt; +use winter_math::{ExtensionOf, FieldElement}; +use winter_utils::collections::Vec; +use winter_utils::{ByteWriter, Serializable}; + +pub struct PublicInputs { + inputs: [Felt; 2], +} + +impl PublicInputs { + pub fn new(inputs: [Felt; 2]) -> Self { + Self { inputs } + } +} + +impl Serializable for PublicInputs { + fn write_into(&self, target: &mut W) { + target.write(self.inputs.as_slice()); + } +} + +pub struct BusesAir { + context: AirContext, + inputs: [Felt; 2], +} + +impl BusesAir { + pub fn last_step(&self) -> usize { + self.trace_length() - self.context().num_transition_exemptions() + } +} + +impl Air for BusesAir { + type BaseField = Felt; + type PublicInputs = PublicInputs; + + fn context(&self) -> &AirContext { + &self.context + } + + fn new(trace_info: TraceInfo, public_inputs: PublicInputs, options: WinterProofOptions) -> Self { + let main_degrees = vec![]; + let aux_degrees = vec![TransitionConstraintDegree::new(1), TransitionConstraintDegree::new(1)]; + let num_main_assertions = 0; + let num_aux_assertions = 4; + + let context = AirContext::new_multi_segment( + trace_info, + main_degrees, + aux_degrees, + num_main_assertions, + num_aux_assertions, + options, + ) + .set_num_transition_exemptions(2); + Self { context, inputs: public_inputs.inputs } + } + + fn get_periodic_column_values(&self) -> Vec> { + vec![] + } + + fn get_assertions(&self) -> Vec> { + let mut result = Vec::new(); + result + } + + fn get_aux_assertions>(&self, aux_rand_elements: &AuxTraceRandElements) -> Vec> { + let mut result = Vec::new(); + result.push(Assertion::single(0, 0, E::ONE)); + result.push(Assertion::single(1, 0, E::ZERO)); + result.push(Assertion::single(0, self.last_step(), E::ONE)); + result.push(Assertion::single(1, self.last_step(), E::ZERO)); + result + } + + fn evaluate_transition>(&self, frame: &EvaluationFrame, periodic_values: &[E], result: &mut [E]) { + let main_current = frame.current(); + let main_next = frame.next(); + } + + fn evaluate_aux_transition(&self, main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], aux_rand_elements: &AuxTraceRandElements, result: &mut [E]) + where F: FieldElement, + E: FieldElement + ExtensionOf, + { + let main_current = main_frame.current(); + let main_next = main_frame.next(); + let aux_current = aux_frame.current(); + let aux_next = aux_frame.next(); + result[0] = (E::ONE + (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1]) * E::ONE + E::ONE - E::ONE) * aux_current[0] - (E::ONE + (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1]) * E::ONE + E::ONE - E::ONE) * aux_next[0]; + result[1] = E::ONE * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * aux_current[1] + E::ZERO + E::ONE * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * E::ONE + E::ONE * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * E::ONE - (E::ONE * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * aux_next[1] + E::ZERO + E::ONE * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * (aux_rand_elements.get_segment_elements(0)[0] + E::ONE * aux_rand_elements.get_segment_elements(0)[1] + E::from(2_u64) * aux_rand_elements.get_segment_elements(0)[2]) * E::from(2_u64)); + } +} \ No newline at end of file diff --git a/air-script/tests/codegen/helpers.rs b/air-script/tests/codegen/helpers.rs index d9cd914aa..9bc51d8cf 100644 --- a/air-script/tests/codegen/helpers.rs +++ b/air-script/tests/codegen/helpers.rs @@ -38,6 +38,7 @@ impl Test { .chain(mir::passes::AstToMir::new(&diagnostics)) .chain(mir::passes::Inlining::new(&diagnostics)) .chain(mir::passes::Unrolling::new(&diagnostics)) + .chain(mir::passes::BusOpExpand::new(&diagnostics)) .chain(air_ir::passes::MirToAir::new(&diagnostics)); pipeline.run(ast) })?, diff --git a/air-script/tests/codegen/masm_with_mir.rs b/air-script/tests/codegen/masm_with_mir.rs index 36959c002..8022ce39f 100644 --- a/air-script/tests/codegen/masm_with_mir.rs +++ b/air-script/tests/codegen/masm_with_mir.rs @@ -64,6 +64,26 @@ fn bitwise() { expected.assert_eq(&generated_masm); } +#[test] +fn buses_simple() { + let generated_masm = Test::new("tests/buses/buses_simple.air".to_string()) + .transpile(Target::Masm, Pipeline::WithMIR) + .unwrap(); + + let expected = expect_file!["../buses/buses_simple.masm"]; + expected.assert_eq(&generated_masm); +} + +#[test] +fn buses_complex() { + let generated_masm = Test::new("tests/buses/buses_complex.air".to_string()) + .transpile(Target::Masm, Pipeline::WithMIR) + .unwrap(); + + let expected = expect_file!["../buses/buses_complex.masm"]; + expected.assert_eq(&generated_masm); +} + #[test] fn constants() { let generated_masm = Test::new("tests/constants/constants.air".to_string()) diff --git a/air-script/tests/codegen/masm_wo_mir.rs b/air-script/tests/codegen/masm_wo_mir.rs index 989b77fa5..45d49347f 100644 --- a/air-script/tests/codegen/masm_wo_mir.rs +++ b/air-script/tests/codegen/masm_wo_mir.rs @@ -24,6 +24,20 @@ fn binary() { expected.assert_eq(&generated_masm); } +#[test] +fn buses_simple() { + Test::new("tests/buses/buses_simple.air".to_string()) + .transpile(Target::Masm, Pipeline::WithoutMIR) + .expect_err("Buses should not be supported in the WithoutMIR pipeline"); +} + +#[test] +fn buses_complex() { + Test::new("tests/buses/buses_complex.air".to_string()) + .transpile(Target::Masm, Pipeline::WithoutMIR) + .expect_err("Buses should not be supported in the WithoutMIR pipeline"); +} + #[test] fn periodic_columns() { let generated_masm = Test::new("tests/periodic_columns/periodic_columns.air".to_string()) diff --git a/air-script/tests/codegen/winterfell_with_mir.rs b/air-script/tests/codegen/winterfell_with_mir.rs index b9ebd84fd..cff412055 100644 --- a/air-script/tests/codegen/winterfell_with_mir.rs +++ b/air-script/tests/codegen/winterfell_with_mir.rs @@ -24,6 +24,26 @@ fn binary() { expected.assert_eq(&generated_air); } +#[test] +fn buses_simple() { + let generated_masm = Test::new("tests/buses/buses_simple.air".to_string()) + .transpile(Target::Winterfell, Pipeline::WithMIR) + .unwrap(); + + let expected = expect_file!["../buses/buses_simple.rs"]; + expected.assert_eq(&generated_masm); +} + +#[test] +fn buses_complex() { + let generated_masm = Test::new("tests/buses/buses_complex.air".to_string()) + .transpile(Target::Winterfell, Pipeline::WithMIR) + .unwrap(); + + let expected = expect_file!["../buses/buses_complex.rs"]; + expected.assert_eq(&generated_masm); +} + #[test] fn periodic_columns() { let generated_air = Test::new("tests/periodic_columns/periodic_columns.air".to_string()) diff --git a/air-script/tests/codegen/winterfell_wo_mir.rs b/air-script/tests/codegen/winterfell_wo_mir.rs index a450ef98d..b3958fc50 100644 --- a/air-script/tests/codegen/winterfell_wo_mir.rs +++ b/air-script/tests/codegen/winterfell_wo_mir.rs @@ -24,6 +24,20 @@ fn binary() { expected.assert_eq(&generated_air); } +#[test] +fn buses_simple() { + Test::new("tests/buses/buses_simple.air".to_string()) + .transpile(Target::Winterfell, Pipeline::WithoutMIR) + .expect_err("Buses should not be supported in the WithoutMIR pipeline"); +} + +#[test] +fn buses_complex() { + Test::new("tests/buses/buses_complex.air".to_string()) + .transpile(Target::Winterfell, Pipeline::WithoutMIR) + .expect_err("Buses should not be supported in the WithoutMIR pipeline"); +} + #[test] fn periodic_columns() { let generated_air = Test::new("tests/periodic_columns/periodic_columns.air".to_string()) diff --git a/air/README.md b/air/README.md index 66e91928e..0388fdf5a 100644 --- a/air/README.md +++ b/air/README.md @@ -19,6 +19,7 @@ let pipeline_with_mir = air_parser::transforms::ConstantPropagation::new(&diagno .chain(mir::passes::AstToMir::new(&diagnostics)) .chain(mir::passes::Inlining::new(&diagnostics)) .chain(mir::passes::Unrolling::new(&diagnostics)) + .chain(mir::passes::BusOpExpand::new(&diagnostics)) .chain(air_ir::passes::MirToAir::new(&diagnostics)); let pipeline_without_mir = air_parser::transforms::ConstantPropagation::new(&diagnostics) diff --git a/air/src/passes/translate_from_ast.rs b/air/src/passes/translate_from_ast.rs index cfcf16f61..7a8738177 100644 --- a/air/src/passes/translate_from_ast.rs +++ b/air/src/passes/translate_from_ast.rs @@ -297,7 +297,8 @@ impl AirBuilder<'_> { } ast::Statement::Enforce(_) | ast::Statement::EnforceIf(_, _) - | ast::Statement::EnforceAll(_) => { + | ast::Statement::EnforceAll(_) + | ast::Statement::BusEnforce(_) => { unreachable!() } } @@ -431,6 +432,13 @@ impl AirBuilder<'_> { ast::Expr::Let(ref let_expr) => self.eval_let_expr(let_expr), // These node types should not exist at this point ast::Expr::Call(_) | ast::Expr::ListComprehension(_) => unreachable!(), + ast::Expr::BusOperation(_) | ast::Expr::Null(_) => { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("buses are not implemented for this Pipeline") + .emit(); + Err(CompileError::Failed) + } } } @@ -447,7 +455,10 @@ impl AirBuilder<'_> { panic!("expected scalar expression to produce scalar value, got: {invalid:?}") } }, - ast::ScalarExpr::Call(_) | ast::ScalarExpr::BoundedSymbolAccess(_) => unreachable!(), + ast::ScalarExpr::Call(_) + | ast::ScalarExpr::BoundedSymbolAccess(_) + | ast::ScalarExpr::BusOperation(_) + | ast::ScalarExpr::Null(_) => unreachable!(), } } diff --git a/air/src/passes/translate_from_mir.rs b/air/src/passes/translate_from_mir.rs index 1d881ee26..3d6ce7649 100644 --- a/air/src/passes/translate_from_mir.rs +++ b/air/src/passes/translate_from_mir.rs @@ -1,4 +1,4 @@ -use std::ops::Deref; +use std::{collections::HashMap, ops::Deref}; use air_parser::{ ast::{self, TraceSegment}, @@ -6,7 +6,7 @@ use air_parser::{ }; use air_pass::Pass; -use miden_diagnostics::{DiagnosticsHandler, Severity, Spanned}; +use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; use mir::ir::{ConstantValue, Link, Mir, MirValue, Op, Parent, SpannedMirValue}; use crate::{graph::NodeIndex, ir::*, CompileError}; @@ -33,7 +33,61 @@ impl Pass for MirToAir<'_> { fn run<'a>(&mut self, mir: Self::Input<'a>) -> Result, Self::Error> { let mut air = Air::new(mir.name); - air.trace_segment_widths = mir.trace_columns.iter().map(|ts| ts.size as u16).collect(); + let buses = mir.constraint_graph().buses.clone(); + + let mut trace_columns = mir.trace_columns.clone(); + + // TODO: When removing aux and rand values, use the following instead + /*if trace_columns.len() != 1 { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("trace columns in mir should only have one segment (main trace)") + .emit(); + return Err(CompileError::Failed); + }*/ + + let mut bus_bindings_map = HashMap::new(); + if !buses.is_empty() { + let existing_aux_segment: Vec<_> = trace_columns + .get(1) + .map(|ts| { + ts.bindings + .iter() + .map(|binding| { + Span::new(binding.span(), (binding.name.unwrap(), binding.size)) + }) + .collect() + }) + .unwrap_or_default(); + + let bus_raw_bindings: Vec<_> = buses + .keys() + .map(|k| Span::new(k.span(), (Identifier::new(k.span(), k.name()), AUX_SEGMENT))) + .collect(); + + // Add buses as `aux` trace columns + let aux_trace_segment = TraceSegment::new( + SourceSpan::default(), + AUX_SEGMENT, + Identifier::new(SourceSpan::default(), Symbol::new(AUX_SEGMENT as u32)), + existing_aux_segment + .into_iter() + .chain(bus_raw_bindings) + .collect(), + ); + for binding in aux_trace_segment.bindings.iter() { + // Also contains non-bus identifiers + bus_bindings_map.insert(binding.name.unwrap(), binding.offset); + } + if trace_columns.len() == 1 { + trace_columns.push(aux_trace_segment); + } else { + let aux = trace_columns.get_mut(1).unwrap(); + *aux = aux_trace_segment; + } + } + + air.trace_segment_widths = trace_columns.iter().map(|ts| ts.size as u16).collect(); air.num_random_values = mir.num_random_values; air.periodic_columns = mir.periodic_columns.clone(); air.public_inputs = mir.public_inputs.clone(); @@ -41,7 +95,8 @@ impl Pass for MirToAir<'_> { let mut builder = AirBuilder { diagnostics: self.diagnostics, air: &mut air, - trace_columns: mir.trace_columns.clone(), + trace_columns: trace_columns.clone(), + bus_bindings_map, }; let graph = mir.constraint_graph(); @@ -62,6 +117,7 @@ struct AirBuilder<'a> { diagnostics: &'a DiagnosticsHandler, air: &'a mut Air, trace_columns: Vec, + bus_bindings_map: HashMap, } /// Helper function to remove the vector wrapper from a scalar operation @@ -196,6 +252,15 @@ impl AirBuilder<'_> { row_offset: trace_access.row_offset, }) } + MirValue::BusAccess(bus_access) => { + let name = bus_access.bus.borrow().deref().name.unwrap(); + let column = self.bus_bindings_map.get(&name).unwrap(); + crate::ir::Value::TraceAccess(crate::ir::TraceAccess { + segment: AUX_SEGMENT, + column: *column, + row_offset: bus_access.row_offset, + }) + } MirValue::PeriodicColumn(periodic_column_access) => { crate::ir::Value::PeriodicColumn(crate::ir::PeriodicColumnAccess { name: periodic_column_access.name, @@ -209,7 +274,7 @@ impl AirBuilder<'_> { }) } MirValue::RandomValue(rv) => crate::ir::Value::RandomValue(*rv), - _ => unreachable!(), + _ => unreachable!("Unexpected MirValue: {:?}", mir_value), }; Ok(self.insert_op(Operation::Value(value))) @@ -243,6 +308,15 @@ impl AirBuilder<'_> { row_offset: offset, }) } + MirValue::BusAccess(bus_access) => { + let name = bus_access.bus.borrow().deref().name.unwrap(); + let column = self.bus_bindings_map.get(&name).unwrap(); + crate::ir::Value::TraceAccess(crate::ir::TraceAccess { + segment: AUX_SEGMENT, + column: *column, + row_offset: offset, + }) + } MirValue::PeriodicColumn(periodic_column_access) => { crate::ir::Value::PeriodicColumn(crate::ir::PeriodicColumnAccess { name: periodic_column_access.name, @@ -331,7 +405,22 @@ impl AirBuilder<'_> { }; (trace_access, lhs_span) } - _ => unreachable!("Expected TraceAccess, received {:?}", value.value), // Raise diag + SpannedMirValue { + value: MirValue::BusAccess(bus_access), + span: lhs_span, + } => { + let bus = bus_access.bus; + let name = bus.borrow().deref().name.unwrap(); + let column = self.bus_bindings_map.get(&name).unwrap(); + // TODO: add offset + let trace_access = + mir::ir::TraceAccess::new(AUX_SEGMENT, *column, bus_access.row_offset); + (trace_access, lhs_span) + } + _ => unreachable!( + "Expected TraceAccess or BusAccess, received {:?}", + value.value + ), // Raise diag }; if let Some(prev) = self.trace_columns[trace_access.segment].mark_constrained( diff --git a/air/src/tests/buses.rs b/air/src/tests/buses.rs new file mode 100644 index 000000000..9b3d170c9 --- /dev/null +++ b/air/src/tests/buses.rs @@ -0,0 +1,151 @@ +use super::{compile, expect_diagnostic, Pipeline}; + +#[test] +fn buses_in_boundary_constraints() { + let source = " + def test + + trace_columns { + main: [a], + } + + buses { + unit p, + mult q, + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf p.first = null; + enf q.first = null; + enf p.last = null; + enf q.last = null; + # TODO: to be used when we have support for variable-length public inputs + #enf p.last = inputs; + #enf q.last = inputs; + } + + integrity_constraints { + enf a = 0; + }"; + + expect_diagnostic( + source, + "buses are not implemented for this Pipeline", + Pipeline::WithoutMIR, + ); + assert!(compile(source, Pipeline::WithMIR).is_ok()); +} + +#[test] +fn buses_in_integrity_constraints() { + let source = " + def test + + trace_columns { + main: [a], + } + + fn double(a: felt) -> felt { + return a+a; + } + + buses { + unit p, + mult q, + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf p.first = null; + enf q.first = null; + enf p.last = null; + enf q.last = null; + # TODO: to be used when we have support for variable-length public inputs + #enf p.last = inputs; + #enf q.last = inputs; + } + + integrity_constraints { + p.add(double(a)) when 1; + p.rem(1) when 1; + q.add(1, 2) when 1; + q.add(1, 2) when 1; + q.rem(1, 2) for 2; + }"; + + expect_diagnostic( + source, + "buses are not implemented for this Pipeline", + Pipeline::WithoutMIR, + ); + assert!(compile(source, Pipeline::WithMIR).is_ok()); +} + +// Tests that should return errors +#[test] +fn err_buses_boundaries_to_const() { + let source = " + def test + + trace_columns { + main: [a], + } + + buses { + unit p, + mult q, + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf p.first = 0; + enf q.last = null; + } + + integrity_constraints { + enf a = 0; + }"; + + expect_diagnostic(source, "error: invalid constraint", Pipeline::WithoutMIR); + expect_diagnostic(source, "error: invalid constraint", Pipeline::WithMIR); +} + +#[test] +fn err_trace_columns_constrained_with_null() { + let source = " + def test + + trace_columns { + main: [a], + } + + buses { + unit p, + mult q, + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.last = null; + } + + integrity_constraints { + enf a = 0; + }"; + + expect_diagnostic(source, "error: invalid constraint", Pipeline::WithoutMIR); + expect_diagnostic(source, "error: invalid constraint", Pipeline::WithMIR); +} diff --git a/air/src/tests/mod.rs b/air/src/tests/mod.rs index 85cd7127d..c0bed8bc7 100644 --- a/air/src/tests/mod.rs +++ b/air/src/tests/mod.rs @@ -1,5 +1,6 @@ mod access; mod boundary_constraints; +mod buses; mod constant; mod evaluators; mod integrity_constraints; @@ -99,6 +100,7 @@ impl Compiler { .chain(mir::passes::AstToMir::new(&self.diagnostics)) .chain(mir::passes::Inlining::new(&self.diagnostics)) .chain(mir::passes::Unrolling::new(&self.diagnostics)) + .chain(mir::passes::BusOpExpand::new(&self.diagnostics)) .chain(crate::passes::MirToAir::new(&self.diagnostics)); pipeline.run(ast) }), diff --git a/air/src/tests/trace.rs b/air/src/tests/trace.rs index 69e4f3fd7..01c815682 100644 --- a/air/src/tests/trace.rs +++ b/air/src/tests/trace.rs @@ -66,8 +66,16 @@ fn err_bc_column_undeclared() { enf clk' = clk + 1; }"; - expect_diagnostic(source, "this variable is not defined", Pipeline::WithoutMIR); - expect_diagnostic(source, "this variable is not defined", Pipeline::WithMIR); + expect_diagnostic( + source, + "this variable / bus is not defined", + Pipeline::WithoutMIR, + ); + expect_diagnostic( + source, + "this variable / bus is not defined", + Pipeline::WithMIR, + ); } #[test] @@ -87,8 +95,16 @@ fn err_ic_column_undeclared() { enf clk' = clk + 1; }"; - expect_diagnostic(source, "this variable is not defined", Pipeline::WithoutMIR); - expect_diagnostic(source, "this variable is not defined", Pipeline::WithMIR); + expect_diagnostic( + source, + "this variable / bus is not defined", + Pipeline::WithoutMIR, + ); + expect_diagnostic( + source, + "this variable / bus is not defined", + Pipeline::WithMIR, + ); } #[test] diff --git a/air/src/tests/variables.rs b/air/src/tests/variables.rs index 9f50bc75c..b0ba135cc 100644 --- a/air/src/tests/variables.rs +++ b/air/src/tests/variables.rs @@ -198,8 +198,8 @@ fn invalid_matrix_literal_with_leading_vector_binding() { enf clk' = d[0][0]; }"; - expect_diagnostic(source, "expected one of: '\"!\"', '\"(\"', 'decl_ident_ref', 'function_identifier', 'identifier', 'int'", Pipeline::WithoutMIR); - expect_diagnostic(source, "expected one of: '\"!\"', '\"(\"', 'decl_ident_ref', 'function_identifier', 'identifier', 'int'", Pipeline::WithMIR); + expect_diagnostic(source, "expected one of: '\"!\"', '\"(\"', '\"null\"', 'decl_ident_ref', 'function_identifier', 'identifier', 'int'", Pipeline::WithoutMIR); + expect_diagnostic(source, "expected one of: '\"!\"', '\"(\"', '\"null\"', 'decl_ident_ref', 'function_identifier', 'identifier', 'int'", Pipeline::WithMIR); } #[test] @@ -248,8 +248,16 @@ fn invalid_variable_access_before_declaration() { enf clk' = clk + 1; }"; - expect_diagnostic(source, "this variable is not defined", Pipeline::WithoutMIR); - expect_diagnostic(source, "this variable is not defined", Pipeline::WithMIR); + expect_diagnostic( + source, + "this variable / bus is not defined", + Pipeline::WithoutMIR, + ); + expect_diagnostic( + source, + "this variable / bus is not defined", + Pipeline::WithMIR, + ); } #[test] @@ -304,8 +312,16 @@ fn invalid_reference_to_variable_defined_in_other_section() { enf clk' = clk + a; }"; - expect_diagnostic(source, "this variable is not defined", Pipeline::WithoutMIR); - expect_diagnostic(source, "this variable is not defined", Pipeline::WithMIR); + expect_diagnostic( + source, + "this variable / bus is not defined", + Pipeline::WithoutMIR, + ); + expect_diagnostic( + source, + "this variable / bus is not defined", + Pipeline::WithMIR, + ); } #[test] diff --git a/docs/src/SUMMARY.md b/docs/src/SUMMARY.md index b0f7905e0..7f50b7f21 100644 --- a/docs/src/SUMMARY.md +++ b/docs/src/SUMMARY.md @@ -8,6 +8,7 @@ - [Constraint descriptions](./description/constraints.md) - [Variables](./description/variables.md) - [Evaluators](./description/evaluators.md) + - [Buses](./description/buses.md) - [Convenience syntax](./description/convenience.md) - [AirScript Example](./description/example.md) - [Keywords](./description/keywords.md) diff --git a/docs/src/description/buses.md b/docs/src/description/buses.md new file mode 100644 index 000000000..1ff57e8af --- /dev/null +++ b/docs/src/description/buses.md @@ -0,0 +1,63 @@ +# Buses + +A bus is a construct that aims to easily describe specific constraints, and can be for instance useful to communicate data between multiple proofs. + +## Bus types + +## Multiset `unit` + +Multiset-based buses can represent constraints specifying given values have been added or removed from a column, in no specific order. + +## LogUp `mult` + +LogUp-based buses are more complex than multiset buses, and can encode the multiplicity of an element: an element can be added or removed multiple times. + +## Defining buses + +See the [declaring buses](./declarations.md#buses) for more details. + +``` +buses { + unit p, + mult q, +} +``` + +## Bus boundary constraints + +In the boundary constraints section, we can constrain the initial and final state of the bus. Currently, only constraining a bus to be empty (with the `null` keyword) is supported. + +``` +boundary_constraints { + enf p.first = null; + enf p.last = null; +} +``` + +The above example states that the bus `p` should be empty at the beginning and end of the trace. + +## Bus integrity constraints + +In the integrity constraints section, we can add and remove elements (as tuples of felts) to and from a bus. In the following examples, `p` and `q` are respectivelly multiset and logup based buses. + +``` +integrity_constraints { + p.add(a) when s1; + p.rem(a, b) when 1 - s2; +} +``` + +Here, `s1` and `1 - s2` are binary selectors: the element is added or removed when the corresponding selector's value is 1. + +The global resulting constraint on the column of the bus is the following: `p ′ ⋅ ( ( α 0 + α 1 ⋅ a + α 2 ⋅ b ) ⋅ ( 1 − s2 ) + s2 ) = p ⋅ ( ( α 0 + α 1 ⋅ a ) ⋅ s1 + 1 − s1 ))`, where `α i` corresponds to the i-th random value provided by the verifier. + +``` +integrity_constraints { + q.rem(e, f, g) when s + q.add(a, b, c) for d +} +``` + +Similarly to the previous example elements can be added or removed from `q` with binary selectors. However, as it is a LogUp-based bus, it is also possible to add and remove elements with a given scalar multiplicity with the `for` keyword (here, `d` does not have to be binary). + +The global resulting constraint on the column of the bus is the following: `q ′ + s ( α 0 + α 1 ⋅ e + α 2 ⋅ f + α 3 ⋅ g ) = q + d ( α 0 + α 1 ⋅ a + α 2 ⋅ b + α 3 ⋅ c )`, where `α i` corresponds to the i-th random value provided by the verifier. diff --git a/docs/src/description/declarations.md b/docs/src/description/declarations.md index 1afa10b4d..254c5497f 100644 --- a/docs/src/description/declarations.md +++ b/docs/src/description/declarations.md @@ -90,6 +90,23 @@ Periodic columns can be referenced by [integrity constraints](./constraints.md#i When constraints are evaluated, these periodic values always refer to the value of the column in the current row. For example, when evaluating an integrity constraint such as `enf k0 * a = 0`, `k0` would be evaluated as `0` in rows `0`, `1`, `2` of the trace and as `1` in row `3`, and then the cycle would repeat. Attempting to refer to the "next" row of a periodic column, such as by `k0'`, is invalid and will cause a `ParseError`. + +## Buses (`buses`) + +A `buses` section contains declarations for buses used in the description and evaluation of integrity constraints. + + +The following is an example of a valid `buses` source section: + +``` +buses { + unit p, + mult q, +} +``` + +In the above example, we declare two buses: `p` of type `unit`, and `q` of type `mult`. They respectively correspond to a multiset-based bus and a LogUp-based bus, that expand to different constraints. More information on bus types can be found in the [buses](./buses.md) section. + ## Random values (`random_values`) A `random_values` section contains declarations for random values provided by the verifier. Random values can be accessed by the named identifier for the whole array or by named bindings to single or grouped random values within the array. diff --git a/docs/src/description/example.md b/docs/src/description/example.md index b7b4bc7ac..724c43820 100644 --- a/docs/src/description/example.md +++ b/docs/src/description/example.md @@ -19,6 +19,10 @@ periodic_columns { k0: [1, 1, 1, 1, 1, 1, 1, 0], } +buses { + mult: q, +} + boundary_constraints { # define boundary constraints against the main trace at the first row of the trace. enf a.first = stack_inputs[0]; @@ -32,6 +36,9 @@ boundary_constraints { # set the first row of the auxiliary column p to 1 enf p.first = 1; + + # set the bus q to be initially empty + enf q.first = null; } integrity_constraints { @@ -49,5 +56,8 @@ integrity_constraints { # the auxiliary column contains the product of values of c offset by a random value. enf p' = p * (c + $rand[0]); + + # add p to the q bus when s = 1 + q.add(p) when s; } ``` diff --git a/docs/src/description/keywords.md b/docs/src/description/keywords.md index 5f5464262..ae65ba476 100644 --- a/docs/src/description/keywords.md +++ b/docs/src/description/keywords.md @@ -3,8 +3,8 @@ AirScript defines the following keywords: - `boundary_constraints`: used to declare the source section where the [boundary constraints are described](./constraints.md#boundary_constraints). - - `first`: used to access the value of a trace column at the first row of the trace. _It may only be used when defining boundary constraints._ - - `last`: used to access the value of a trace column at the last row of the trace. _It may only be used when defining boundary constraints._ + - `first`: used to access the value of a trace column / bus at the first row of the trace. _It may only be used when defining boundary constraints._ + - `last`: used to access the value of a trace column / bus at the last row of the trace. _It may only be used when defining boundary constraints._ - `case`: used to declare arms of [conditional constraints](./convenience.md#conditional-constraints). - `const`: used to declare [constants](./declarations.md#constant-constant). - `def`: used to [define the name](./organization.md#root-module) of a root AirScript module. diff --git a/docs/src/description/main.md b/docs/src/description/main.md index cd3d032e0..ff9fae332 100644 --- a/docs/src/description/main.md +++ b/docs/src/description/main.md @@ -6,6 +6,7 @@ - [Constraint descriptions](./constraints.md) - [Variables](./variables.md) - [Evaluators](./evaluators.md) +- [Buses](./buses.md) - [Convenience syntax](./convenience.md) - [AirScript Example](./example.md) - [Keywords](./keywords.md) diff --git a/docs/src/description/organization.md b/docs/src/description/organization.md index c04b18dd7..7dae8f34e 100644 --- a/docs/src/description/organization.md +++ b/docs/src/description/organization.md @@ -12,6 +12,7 @@ All modules must start with a module name declaration followed by a set of sourc | [trace columns](./declarations.md#execution-trace-trace_columns) | required | not allowed | | [public inputs](./declarations.md#public-inputs-public_inputs) | required | not allowed | | [periodic columns](./declarations.md#periodic-columns-periodic_columns) | optional | optional | +| [buses](./declarations.md#buses-buses) | optional | optional | | [random values](./declarations.md#random-values-random_values) | optional | not allowed | | [boundary constraints](./constraints.md#boundary-constraints-boundary_constraints) | required | not allowed | | [integrity constraints](./constraints.md#integrity-constraints-integrity_constraints) | required | not allowed | diff --git a/docs/src/introduction.md b/docs/src/introduction.md index 9f42578b9..71cbef955 100644 --- a/docs/src/introduction.md +++ b/docs/src/introduction.md @@ -16,6 +16,8 @@ AirScript includes the following features: - **Random Values**: Users can define random values provided by the verifier (e.g. `alphas: [x, y[14], z],` or `rand: [16],`) +- **Buses**: Users can declare buses (e.g. `unit p,`) + - **Boundary Constraints**: Users can enforce boundary constraints on main and auxiliary trace columns using public inputs, random values, constants and variables. - **Integrity Constraints**: Users can enforce integrity constraints on main and auxiliary trace columns using trace columns, periodic columns, random values, constants and variables. diff --git a/mir/Cargo.toml b/mir/Cargo.toml index d8b4d88ed..f593c39fc 100644 --- a/mir/Cargo.toml +++ b/mir/Cargo.toml @@ -18,3 +18,4 @@ anyhow = "1.0" miden-diagnostics = "0.1" thiserror = "1.0" derive-ir = { package = "derive-ir", path = "./derive-ir", version = "0.1" } +pretty_assertions = "1.4.1" diff --git a/mir/derive-ir/src/builder.rs b/mir/derive-ir/src/builder.rs index 5a8b42489..855e280d5 100644 --- a/mir/derive-ir/src/builder.rs +++ b/mir/derive-ir/src/builder.rs @@ -429,6 +429,7 @@ fn make_builder_struct<'a>( let builder_struct_name = format_ident!("{}Builder", name); let struct_fields = fields.iter().map(|(_, _, field, _, _, _)| field); let builder_struct = quote! { + #[derive(Debug)] pub struct #builder_struct_name { _builder_state: std::marker::PhantomData, #(#struct_fields),* @@ -446,8 +447,14 @@ fn make_builder_aliases<'a>( no_name: &'a syn::Ident, ) -> (syn::Ident, syn::Ident, proc_macro2::TokenStream) { let (yes, no) = ( - quote! { pub struct #yes_name; }, - quote! { pub struct #no_name; }, + quote! { + #[derive(Debug)] + pub struct #yes_name; + }, + quote! { + #[derive(Debug)] + pub struct #no_name; + }, ); let builder_struct_name = format_ident!("{}Builder", name); let mut alias_names = vec![]; @@ -616,10 +623,12 @@ mod tests { cs: Vec>, d: Link, count: i32, + _singleton: Singleton, _hidden: i32, } }; let expected = quote! { + #[derive(Debug)] pub struct FooBuilder { _builder_state: std::marker::PhantomData, parent: BackLink, @@ -629,7 +638,9 @@ mod tests { d: Option>, count: Option } + #[derive(Debug)] pub struct #y; + #[derive(Debug)] pub struct #n; type FooBuilderState0 = FooBuilder<(#y, #n, #y, #y, #n, #n)>; type FooBuilderState1 = FooBuilder<(#y, #y, #y, #y, #n, #n)>; @@ -913,6 +924,7 @@ mod tests { cs: self.cs.clone(), d: self.d.clone().unwrap(), count: self.count.clone().unwrap(), + _singleton: Default::default(), _hidden: Default::default() } ).into() diff --git a/mir/src/ir/bus.rs b/mir/src/ir/bus.rs new file mode 100644 index 000000000..c6a583862 --- /dev/null +++ b/mir/src/ir/bus.rs @@ -0,0 +1,171 @@ +use std::ops::Deref; + +use air_parser::ast::{self, Identifier}; + +use miden_diagnostics::{SourceSpan, Spanned}; + +use crate::{ + ir::{BackLink, Builder, BusOp, BusOpKind, Link, Op}, + CompileError, +}; + +/// A Mir struct to represent a Bus definition +/// we have 2 cases: +/// +/// - [BusType::Unit]: multiset check +/// +/// these constraints: +/// ```air +/// p.add(a, b) when s +/// p.rem(c, d) when (1 - s) +/// ``` +/// translate to this equation: +/// ```tex +/// p′⋅((α0+α1⋅c+α2⋅d)⋅(1−s)+s)=p⋅((α0+α1⋅a+α2⋅b)⋅s+1−s) +/// ``` +/// with this bus definition: +/// ```ignore +/// Bus { +/// bus_type: BusType::Unit, +/// columns: [a, b, c, d], +/// latches: [s, 1 - s], +/// } +/// ``` +/// with: +/// a, b, c, d, s being [Link] in the graph +/// s, 1 - s being [Link] representing booleans in the graph +/// +/// - [BusType::Mult]: LogUp bus +/// +/// these constraints: +/// ```air +/// q.add(a, b, c) for d +/// q.rem(e, f, g) when s +/// ``` +/// translate to this equation: +/// ```tex +/// q′+s/(α0+α1·e+α2·f+α3·g)=q+d/(α0+α1·a+α2·b+α3·c) +/// ``` +/// with this bus definition: +/// ```ignore +/// Bus { +/// bus_type: BusType::Mult, +/// columns: [a, b, c, e, f, g], +/// latches: [d, s], +/// } +/// ``` +/// with: +/// a, b, c, e, f, g being [Link] in the graph +/// d, s being [Link], s is boolean, d is a number. +#[derive(Default, Clone, Eq, Debug, Spanned)] +pub struct Bus { + /// Identifier of the bus + pub name: Option, + /// Type of bus + pub bus_type: ast::BusType, + /// values stored in the bus + /// colums are joined with randomness (αi) in the bus constraint equation + pub columns: Vec>, + /// selectors denoting when a value is present + pub latches: Vec>, + first: Link, + last: Link, + #[span] + span: SourceSpan, +} + +impl std::hash::Hash for Bus { + fn hash(&self, state: &mut H) { + self.name.hash(state); + self.bus_type.hash(state); + self.columns.hash(state); + self.latches.hash(state); + } +} + +impl PartialEq for Bus { + fn eq(&self, other: &Self) -> bool { + self.name == other.name + && self.bus_type == other.bus_type + && self.columns == other.columns + && self.latches == other.latches + } +} + +impl Bus { + pub fn create(name: Identifier, bus_type: ast::BusType, span: SourceSpan) -> Link { + Bus { + name: Some(name), + bus_type, + span, + ..Default::default() + } + .into() + } + + pub fn set_first(&mut self, first: Link) -> Result<(), CompileError> { + let Op::None(_) = self.first.borrow().deref() else { + return Err(CompileError::Failed); + }; + self.first = first; + Ok(()) + } + + pub fn set_last(&mut self, last: Link) -> Result<(), CompileError> { + let Op::None(_) = self.last.borrow().deref() else { + return Err(CompileError::Failed); + }; + self.last = last; + Ok(()) + } + + pub fn get_first(&self) -> Link { + self.first.clone() + } + + pub fn get_last(&self) -> Link { + self.last.clone() + } + + pub fn get_name(&self) -> Option { + self.name + } +} + +impl Link { + pub fn add(&self, columns: &[Link], latch: Link, span: SourceSpan) -> Link { + self.bus_op(BusOpKind::Add, columns, latch, span) + } + + pub fn rem(&self, columns: &[Link], latch: Link, span: SourceSpan) -> Link { + self.bus_op(BusOpKind::Rem, columns, latch, span) + } + + fn bus_op( + &self, + kind: BusOpKind, + columns: &[Link], + latch: Link, + span: SourceSpan, + ) -> Link { + let mut bus_op = BusOp::builder().bus(self.clone()).kind(kind).span(span); + for column in columns { + bus_op = bus_op.args(column.clone()); + } + let bus_op = bus_op.build(); + let bus_op_ref = bus_op.as_bus_op_mut().unwrap(); + bus_op_ref._latch.update(&latch); + drop(bus_op_ref); + self.borrow_mut().columns.push(bus_op.clone()); + self.borrow_mut().latches.push(latch.clone()); + bus_op + } +} + +impl BackLink { + pub fn get_name(&self) -> Option { + self.to_link() + .map(|l| l.borrow().get_name()) + .unwrap_or_else(|| panic!("Bus was dropped")) + } +} diff --git a/mir/src/ir/graph.rs b/mir/src/ir/graph.rs index 112c1d0b2..b3a5f5803 100644 --- a/mir/src/ir/graph.rs +++ b/mir/src/ir/graph.rs @@ -1,37 +1,37 @@ -use crate::{ - ir::{Evaluator, Function, Link, Op, Root}, - CompileError, -}; +use crate::{ir, CompileError}; use std::{ cell::{Ref, RefMut}, collections::BTreeMap, }; -use air_parser::ast::QualifiedIdentifier; +use air_parser::ast::{Identifier, QualifiedIdentifier}; +use miden_diagnostics::Spanned; /// The constraints graph for the Mir. /// /// We store constraints (boundary and integrity), as well as function and evaluator definitions. /// -#[derive(Debug, Default)] +#[derive(Debug, Default, PartialEq, Eq)] pub struct Graph { - functions: BTreeMap>, - evaluators: BTreeMap>, - pub boundary_constraints_roots: Link>>, - pub integrity_constraints_roots: Link>>, + functions: BTreeMap>, + evaluators: BTreeMap>, + pub boundary_constraints_roots: ir::Link>>, + pub integrity_constraints_roots: ir::Link>>, + pub buses: BTreeMap>, + bus_count: usize, } impl Graph { - pub fn create() -> Link { + pub fn create() -> ir::Link { Graph::default().into() } - /// Inserts a function into the graph, returning an error if the root is not a Function, + /// Inserts a function into the graph, returning an error if the root is not a [ir::Function], /// or if the function already exists (declaration conflict). pub fn insert_function( &mut self, ident: QualifiedIdentifier, - node: Link, + node: ir::Link, ) -> Result<(), CompileError> { if node.as_function().is_none() { return Err(CompileError::Failed); @@ -39,7 +39,7 @@ impl Graph { match self.functions.insert(ident, node) { None => Ok(()), Some(link) => { - if let Root::None(_) = *link.borrow() { + if let ir::Root::None(_) = *link.borrow() { Ok(()) } else { Err(CompileError::Failed) @@ -49,18 +49,21 @@ impl Graph { } /// Queries a given function as a root - pub fn get_function_root(&self, ident: &QualifiedIdentifier) -> Option> { + pub fn get_function_root(&self, ident: &QualifiedIdentifier) -> Option> { self.functions.get(ident).cloned() } - /// Queries a given function as a Function - pub fn get_function(&self, ident: &QualifiedIdentifier) -> Option> { + /// Queries a given function as a [ir::Function] + pub fn get_function(&self, ident: &QualifiedIdentifier) -> Option> { // Unwrap is safe as we ensure the type is correct before inserting self.functions.get(ident).map(|n| n.as_function().unwrap()) } - /// Queries a given function as a mutable Function - pub fn get_function_mut(&mut self, ident: &QualifiedIdentifier) -> Option> { + /// Queries a given function as a mutable [ir::Function] + pub fn get_function_mut( + &mut self, + ident: &QualifiedIdentifier, + ) -> Option> { // Unwrap is safe as we ensure the type is correct before inserting self.functions .get_mut(ident) @@ -68,16 +71,16 @@ impl Graph { } /// Queries all function nodes - pub fn get_function_nodes(&self) -> Vec> { + pub fn get_function_nodes(&self) -> Vec> { self.functions.values().cloned().collect() } - /// Inserts an evaluator into the graph, returning an error if the root is not an Evaluator, + /// Inserts an evaluator into the graph, returning an error if the root is not an [ir::Evaluator], /// or if the evaluator already exists (declaration conflict). pub fn insert_evaluator( &mut self, ident: QualifiedIdentifier, - node: Link, + node: ir::Link, ) -> Result<(), CompileError> { if node.as_evaluator().is_none() { return Err(CompileError::Failed); @@ -85,7 +88,7 @@ impl Graph { match self.evaluators.insert(ident, node) { None => Ok(()), Some(link) => { - if let Root::None(_) = *link.borrow() { + if let ir::Root::None(_) = *link.borrow() { Ok(()) } else { Err(CompileError::Failed) @@ -95,20 +98,23 @@ impl Graph { } /// Queries a given evaluator as a root - pub fn get_evaluator_root(&self, ident: &QualifiedIdentifier) -> Option> { + pub fn get_evaluator_root(&self, ident: &QualifiedIdentifier) -> Option> { self.evaluators.get(ident).cloned() } - /// Queries a given evaluator as a mutable Evaluator - pub fn get_evaluator(&self, ident: &QualifiedIdentifier) -> Option> { + /// Queries a given evaluator as a mutable [ir::Evaluator] + pub fn get_evaluator(&self, ident: &QualifiedIdentifier) -> Option> { // Unwrap is safe as we ensure the type is correct before inserting self.evaluators .get(ident) .map(|n| n.as_evaluator().unwrap()) } - /// Queries a given evaluator as a mutable Evaluator - pub fn get_evaluator_mut(&mut self, ident: &QualifiedIdentifier) -> Option> { + /// Queries a given evaluator as a mutable [ir::Evaluator] + pub fn get_evaluator_mut( + &mut self, + ident: &QualifiedIdentifier, + ) -> Option> { // Unwrap is safe as we ensure the type is correct before inserting self.evaluators .get_mut(ident) @@ -116,12 +122,12 @@ impl Graph { } /// Queries all evaluator nodes - pub fn get_evaluator_nodes(&self) -> Vec> { + pub fn get_evaluator_nodes(&self) -> Vec> { self.evaluators.values().cloned().collect() } /// Inserts a boundary constraint into the graph, if it does not already exist. - pub fn insert_boundary_constraints_root(&mut self, root: Link) { + pub fn insert_boundary_constraints_root(&mut self, root: ir::Link) { if !self.boundary_constraints_roots.borrow().contains(&root) { self.boundary_constraints_roots .borrow_mut() @@ -130,14 +136,14 @@ impl Graph { } /// Removes a boundary constraint from the graph. - pub fn remove_boundary_constraints_root(&mut self, root: Link) { + pub fn remove_boundary_constraints_root(&mut self, root: ir::Link) { self.boundary_constraints_roots .borrow_mut() .retain(|n| *n != root); } /// Inserts an integrity constraint into the graph, if it does not already exist. - pub fn insert_integrity_constraints_root(&mut self, root: Link) { + pub fn insert_integrity_constraints_root(&mut self, root: ir::Link) { if !self.integrity_constraints_roots.borrow().contains(&root) { self.integrity_constraints_roots .borrow_mut() @@ -146,9 +152,44 @@ impl Graph { } /// Removes an integrity constraint from the graph. - pub fn remove_integrity_constraints_root(&mut self, root: Link) { + pub fn remove_integrity_constraints_root(&mut self, root: ir::Link) { self.boundary_constraints_roots .borrow_mut() .retain(|n| *n != root); } + + /// Inserts a bus into the graph, + /// returning an error if the bus already exists (declaration conflict). + pub fn insert_bus( + &mut self, + ident: QualifiedIdentifier, + bus: ir::Link, + ) -> Result<(), CompileError> { + bus.borrow_mut().name = Some(Identifier::new(bus.span(), ident.name())); + self.buses + .insert(ident, bus) + .map_or(Ok(()), |_| Err(CompileError::Failed)) + } + + /// Queries a given bus + /// returning a [ir::Link] if it exists. + pub fn get_bus_link(&self, ident: &QualifiedIdentifier) -> Option> { + self.buses.get(ident).cloned() + } + /// Queries a given bus + /// returning a reference to the bus if it exists. + pub fn get_bus(&self, ident: &QualifiedIdentifier) -> Option> { + self.buses.get(ident).map(|n| n.borrow()) + } + + /// Queries a given bus + /// returning a mutable reference to the bus if it exists. + pub fn get_bus_mut(&mut self, ident: &QualifiedIdentifier) -> Option> { + self.buses.get_mut(ident).map(|n| n.borrow_mut()) + } + + /// Queries all bus nodes + pub fn get_bus_nodes(&self) -> Vec> { + self.buses.values().cloned().collect() + } } diff --git a/mir/src/ir/link.rs b/mir/src/ir/link.rs index 7c8aad2bf..8c3147e7d 100644 --- a/mir/src/ir/link.rs +++ b/mir/src/ir/link.rs @@ -72,6 +72,15 @@ impl PartialEq for Link { impl Eq for Link where T: Eq {} +impl Hash for Link +where + T: Hash, +{ + fn hash(&self, state: &mut H) { + self.link.borrow().hash(state) + } +} + impl From for Link { fn from(value: T) -> Self { Self::new(value) @@ -102,15 +111,6 @@ where } } -impl Hash for Link -where - T: Hash, -{ - fn hash(&self, state: &mut H) { - self.link.borrow().hash(state) - } -} - /// A wrapper around a `Option>>` to allow custom trait implementations. /// Used instead of `Link` where a `Link` would create a cyclIc reference. pub struct BackLink { @@ -197,3 +197,55 @@ where } } } + +/// A wrapper around a [Link] to block recursive implementations of [PartialEq] and [Hash]. +#[derive(Clone)] +pub struct Singleton(pub Option>); + +impl Singleton { + pub fn new(value: Link) -> Self { + Self(Some(value)) + } + pub fn none() -> Self { + Self(None) + } + pub fn to_link(&self) -> Option> { + self.0.clone() + } +} + +impl Debug for Singleton { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.0.fmt(f) + } +} + +impl PartialEq for Singleton { + fn eq(&self, _other: &Self) -> bool { + true + } +} + +impl Eq for Singleton {} + +impl Hash for Singleton { + fn hash(&self, _state: &mut H) {} +} + +impl Default for Singleton { + fn default() -> Self { + Self::none() + } +} + +impl From for Singleton { + fn from(value: T) -> Self { + Self::from(Link::from(value)) + } +} + +impl From> for Singleton { + fn from(value: Link) -> Self { + Self(Some(value)) + } +} diff --git a/mir/src/ir/mir.rs b/mir/src/ir/mir.rs index 5c0442586..0493d3ca4 100644 --- a/mir/src/ir/mir.rs +++ b/mir/src/ir/mir.rs @@ -17,7 +17,7 @@ use super::Graph; /// It is equivalent to an [air_parser::ast::Program], except that it has been /// translated into an algebraic graph representation, on which further analysis, /// optimization, and code generation are performed. -#[derive(Debug, Spanned)] +#[derive(Debug, Spanned, PartialEq, Eq)] pub struct Mir { /// The name of the [air_parser::ast::Program] from which this MIR was derived #[span] diff --git a/mir/src/ir/mod.rs b/mir/src/ir/mod.rs index cc202c23f..98144b5d5 100644 --- a/mir/src/ir/mod.rs +++ b/mir/src/ir/mod.rs @@ -1,19 +1,22 @@ +mod bus; mod graph; mod link; mod mir; mod node; mod nodes; mod owner; +mod utils; pub extern crate derive_ir; +pub use bus::Bus; pub use derive_ir::Builder; pub use graph::Graph; -pub use link::{BackLink, Link}; +pub use link::{BackLink, Link, Singleton}; pub use mir::Mir; pub use node::Node; pub use nodes::*; pub use owner::Owner; - +pub use utils::*; /// A trait for nodes that can have children /// This is used with the Child trait to allow for easy traversal and manipulation of the graph pub trait Parent { diff --git a/mir/src/ir/node.rs b/mir/src/ir/node.rs index 6ade30dee..c1224bf70 100644 --- a/mir/src/ir/node.rs +++ b/mir/src/ir/node.rs @@ -1,7 +1,6 @@ -use crate::ir::{BackLink, Child, Op}; +use crate::ir::{BackLink, Child, Link, Op, Owner, Parent, Root}; use miden_diagnostics::{SourceSpan, Spanned}; -use super::{Link, Owner, Parent, Root}; use std::ops::Deref; /// All the nodes that can be in the MIR Graph @@ -31,6 +30,7 @@ pub enum Node { Vector(BackLink), Matrix(BackLink), Accessor(BackLink), + BusOp(BackLink), Parameter(BackLink), Value(BackLink), None(SourceSpan), @@ -61,6 +61,7 @@ impl PartialEq for Node { (Node::Vector(lhs), Node::Vector(rhs)) => lhs.to_link() == rhs.to_link(), (Node::Matrix(lhs), Node::Matrix(rhs)) => lhs.to_link() == rhs.to_link(), (Node::Accessor(lhs), Node::Accessor(rhs)) => lhs.to_link() == rhs.to_link(), + (Node::BusOp(lhs), Node::BusOp(rhs)) => lhs.to_link() == rhs.to_link(), (Node::Parameter(lhs), Node::Parameter(rhs)) => lhs.to_link() == rhs.to_link(), (Node::Value(lhs), Node::Value(rhs)) => lhs.to_link() == rhs.to_link(), (Node::None(_), Node::None(_)) => true, @@ -72,6 +73,7 @@ impl PartialEq for Node { impl std::hash::Hash for Node { fn hash(&self, state: &mut H) { // We first convert the [BackLink] to an [Option] and hash those + eprintln!("hashing node: {:#?}", self); match self { Node::Function(f) => f.to_link().hash(state), Node::Evaluator(e) => e.to_link().hash(state), @@ -88,6 +90,7 @@ impl std::hash::Hash for Node { Node::Vector(v) => v.to_link().hash(state), Node::Matrix(m) => m.to_link().hash(state), Node::Accessor(a) => a.to_link().hash(state), + Node::BusOp(b) => b.to_link().hash(state), Node::Parameter(p) => p.to_link().hash(state), Node::Value(v) => v.to_link().hash(state), Node::None(_) => (), @@ -114,6 +117,7 @@ impl Parent for Node { Node::Vector(v) => v.children(), Node::Matrix(m) => m.children(), Node::Accessor(a) => a.children(), + Node::BusOp(b) => b.children(), Node::Parameter(_p) => Link::default(), Node::Value(_v) => Link::default(), Node::None(_) => Link::default(), @@ -140,6 +144,7 @@ impl Child for Node { Node::Vector(v) => v.get_parents(), Node::Matrix(m) => m.get_parents(), Node::Accessor(a) => a.get_parents(), + Node::BusOp(b) => b.get_parents(), Node::Parameter(p) => p.get_parents(), Node::Value(v) => v.get_parents(), Node::None(_) => Vec::default(), @@ -162,6 +167,7 @@ impl Child for Node { Node::Vector(v) => v.add_parent(parent), Node::Matrix(m) => m.add_parent(parent), Node::Accessor(a) => a.add_parent(parent), + Node::BusOp(b) => b.add_parent(parent), Node::Parameter(p) => p.add_parent(parent), Node::Value(v) => v.add_parent(parent), Node::None(_) => (), @@ -184,6 +190,7 @@ impl Child for Node { Node::Vector(v) => v.remove_parent(parent), Node::Matrix(m) => m.remove_parent(parent), Node::Accessor(a) => a.remove_parent(parent), + Node::BusOp(b) => b.remove_parent(parent), Node::Parameter(p) => p.remove_parent(parent), Node::Value(v) => v.remove_parent(parent), Node::None(_) => (), @@ -211,6 +218,7 @@ impl Link { Op::Vector(_) => Node::Vector(BackLink::from(op_inner_val)), Op::Matrix(_) => Node::Matrix(BackLink::from(op_inner_val)), Op::Accessor(_) => Node::Accessor(BackLink::from(op_inner_val)), + Op::BusOp(_) => Node::BusOp(BackLink::from(op_inner_val)), Op::Parameter(_) => Node::Parameter(BackLink::from(op_inner_val)), Op::Value(_) => Node::Value(BackLink::from(op_inner_val)), Op::None(span) => Node::None(*span), @@ -265,6 +273,7 @@ impl Link { Node::Vector(_) => None, Node::Matrix(_) => None, Node::Accessor(_) => None, + Node::BusOp(_) => None, Node::Parameter(_) => None, Node::Value(_) => None, Node::None(_) => None, @@ -289,6 +298,7 @@ impl Link { Node::Vector(inner) => inner.to_link(), Node::Matrix(inner) => inner.to_link(), Node::Accessor(inner) => inner.to_link(), + Node::BusOp(inner) => inner.to_link(), Node::Parameter(inner) => inner.to_link(), Node::Value(inner) => inner.to_link(), Node::None(_) => None, diff --git a/mir/src/ir/nodes/op.rs b/mir/src/ir/nodes/op.rs index f76c4b6d6..ec988c5c4 100644 --- a/mir/src/ir/nodes/op.rs +++ b/mir/src/ir/nodes/op.rs @@ -1,6 +1,7 @@ use crate::ir::{ - get_inner, get_inner_mut, Accessor, Add, BackLink, Boundary, Call, Child, Enf, Exp, Fold, For, - If, Link, Matrix, Mul, Node, Owner, Parameter, Parent, Sub, Value, Vector, + get_inner, get_inner_mut, Accessor, Add, BackLink, Boundary, BusOp, Call, Child, ConstantValue, + Enf, Exp, Fold, For, If, Link, Matrix, MirValue, Mul, Node, Owner, Parameter, Parent, + Singleton, SpannedMirValue, Sub, Value, Vector, }; use miden_diagnostics::{SourceSpan, Spanned}; @@ -27,6 +28,7 @@ pub enum Op { Vector(Vector), Matrix(Matrix), Accessor(Accessor), + BusOp(BusOp), Parameter(Parameter), Value(Value), None(SourceSpan), @@ -55,6 +57,7 @@ impl Parent for Op { Op::Vector(v) => v.children(), Op::Matrix(m) => m.children(), Op::Accessor(a) => a.children(), + Op::BusOp(b) => b.children(), Op::Parameter(_) => Link::default(), Op::Value(_) => Link::default(), Op::None(_) => Link::default(), @@ -79,6 +82,7 @@ impl Child for Op { Op::Vector(v) => v.get_parents(), Op::Matrix(m) => m.get_parents(), Op::Accessor(a) => a.get_parents(), + Op::BusOp(b) => b.get_parents(), Op::Parameter(p) => p.get_parents(), Op::Value(v) => v.get_parents(), Op::None(_) => Default::default(), @@ -99,6 +103,7 @@ impl Child for Op { Op::Vector(v) => v.add_parent(parent), Op::Matrix(m) => m.add_parent(parent), Op::Accessor(a) => a.add_parent(parent), + Op::BusOp(b) => b.add_parent(parent), Op::Parameter(p) => p.add_parent(parent), Op::Value(v) => v.add_parent(parent), Op::None(_) => {} @@ -119,6 +124,7 @@ impl Child for Op { Op::Vector(v) => v.remove_parent(parent), Op::Matrix(m) => m.remove_parent(parent), Op::Accessor(a) => a.remove_parent(parent), + Op::BusOp(b) => b.remove_parent(parent), Op::Parameter(p) => p.remove_parent(parent), Op::Value(v) => v.remove_parent(parent), Op::None(_) => {} @@ -144,6 +150,7 @@ impl Link { Op::Vector(v) => format!("Op::Vector@{}({:#?})", self.get_ptr(), v), Op::Matrix(m) => format!("Op::Matrix@{}({:#?})", self.get_ptr(), m), Op::Accessor(a) => format!("Op::Accessor@{}({:#?})", self.get_ptr(), a), + Op::BusOp(b) => format!("Op::BusOp@{}({:#?})", self.get_ptr(), b), Op::Parameter(p) => format!("Op::Parameter@{}({:#?})", self.get_ptr(), p), Op::Value(v) => format!("Op::Value@{}({:#?})", self.get_ptr(), v), Op::None(_) => "Op::None".to_string(), @@ -174,49 +181,52 @@ impl Link { fn update_inner_node(&self, node: &Link) { match self.clone().borrow_mut().deref_mut() { Op::Enf(ref mut enf) => { - enf._node = Some(node.clone()); + enf._node = Singleton::from(node.clone()); } Op::Boundary(ref mut boundary) => { - boundary._node = Some(node.clone()); + boundary._node = Singleton::from(node.clone()); } Op::Add(ref mut add) => { - add._node = Some(node.clone()); + add._node = Singleton::from(node.clone()); } Op::Sub(ref mut sub) => { - sub._node = Some(node.clone()); + sub._node = Singleton::from(node.clone()); } Op::Mul(ref mut mul) => { - mul._node = Some(node.clone()); + mul._node = Singleton::from(node.clone()); } Op::Exp(ref mut exp) => { - exp._node = Some(node.clone()); + exp._node = Singleton::from(node.clone()); } Op::If(ref mut if_op) => { - if_op._node = Some(node.clone()); + if_op._node = Singleton::from(node.clone()); } Op::For(ref mut for_op) => { - for_op._node = Some(node.clone()); + for_op._node = Singleton::from(node.clone()); } Op::Call(ref mut call) => { - call._node = Some(node.clone()); + call._node = Singleton::from(node.clone()); } Op::Fold(ref mut fold) => { - fold._node = Some(node.clone()); + fold._node = Singleton::from(node.clone()); } Op::Vector(ref mut vector) => { - vector._node = Some(node.clone()); + vector._node = Singleton::from(node.clone()); } Op::Matrix(ref mut matrix) => { - matrix._node = Some(node.clone()); + matrix._node = Singleton::from(node.clone()); } Op::Accessor(ref mut accessor) => { - accessor._node = Some(node.clone()); + accessor._node = Singleton::from(node.clone()); + } + Op::BusOp(ref mut bus_op) => { + bus_op._node = Singleton::from(node.clone()); } Op::Parameter(ref mut parameter) => { - parameter._node = Some(node.clone()); + parameter._node = Singleton::from(node.clone()); } Op::Value(ref mut value) => { - value._node = Some(node.clone()); + value._node = Singleton::from(node.clone()); } Op::None(_) => {} } @@ -225,43 +235,46 @@ impl Link { fn update_inner_owner(&self, owner: &Link) { match self.clone().borrow_mut().deref_mut() { Op::Enf(ref mut enf) => { - enf._owner = Some(owner.clone()); + enf._owner = Singleton::from(owner.clone()); } Op::Boundary(ref mut boundary) => { - boundary._owner = Some(owner.clone()); + boundary._owner = Singleton::from(owner.clone()); } Op::Add(ref mut add) => { - add._owner = Some(owner.clone()); + add._owner = Singleton::from(owner.clone()); } Op::Sub(ref mut sub) => { - sub._owner = Some(owner.clone()); + sub._owner = Singleton::from(owner.clone()); } Op::Mul(ref mut mul) => { - mul._owner = Some(owner.clone()); + mul._owner = Singleton::from(owner.clone()); } Op::Exp(ref mut exp) => { - exp._owner = Some(owner.clone()); + exp._owner = Singleton::from(owner.clone()); } Op::If(ref mut if_op) => { - if_op._owner = Some(owner.clone()); + if_op._owner = Singleton::from(owner.clone()); } Op::For(ref mut for_op) => { - for_op._owner = Some(owner.clone()); + for_op._owner = Singleton::from(owner.clone()); } Op::Call(ref mut call) => { - call._owner = Some(owner.clone()); + call._owner = Singleton::from(owner.clone()); } Op::Fold(ref mut fold) => { - fold._owner = Some(owner.clone()); + fold._owner = Singleton::from(owner.clone()); } Op::Vector(ref mut vector) => { - vector._owner = Some(owner.clone()); + vector._owner = Singleton::from(owner.clone()); } Op::Matrix(ref mut matrix) => { - matrix._owner = Some(owner.clone()); + matrix._owner = Singleton::from(owner.clone()); } Op::Accessor(ref mut accessor) => { - accessor._owner = Some(owner.clone()); + accessor._owner = Singleton::from(owner.clone()); + } + Op::BusOp(ref mut bus_op) => { + bus_op._owner = Singleton::from(owner.clone()); } Op::Parameter(ref mut _parameter) => {} Op::Value(ref mut _value) => {} @@ -275,123 +288,147 @@ impl Link { let back: BackLink = self.clone().into(); match self.clone().borrow_mut().deref_mut() { Op::Enf(Enf { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Enf(ref mut enf) => { let node: Link = Node::Enf(back).into(); - enf._node = Some(node.clone()); + enf._node = Singleton::from(node.clone()); node } Op::Boundary(Boundary { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Boundary(ref mut boundary) => { let node: Link = Node::Boundary(back).into(); - boundary._node = Some(node.clone()); + boundary._node = Singleton::from(node.clone()); node } Op::Add(Add { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Add(ref mut add) => { let node: Link = Node::Add(back).into(); - add._node = Some(node.clone()); + add._node = Singleton::from(node.clone()); node } Op::Sub(Sub { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Sub(ref mut sub) => { let node: Link = Node::Sub(back).into(); - sub._node = Some(node.clone()); + sub._node = Singleton::from(node.clone()); node } Op::Mul(Mul { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Mul(ref mut mul) => { let node: Link = Node::Mul(back).into(); - mul._node = Some(node.clone()); + mul._node = Singleton::from(node.clone()); node } Op::Exp(Exp { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Exp(ref mut exp) => { let node: Link = Node::Exp(back).into(); - exp._node = Some(node.clone()); + exp._node = Singleton::from(node.clone()); node } Op::If(If { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::If(ref mut if_op) => { let node: Link = Node::If(back).into(); - if_op._node = Some(node.clone()); + if_op._node = Singleton::from(node.clone()); node } Op::For(For { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::For(ref mut for_op) => { let node: Link = Node::For(back).into(); - for_op._node = Some(node.clone()); + for_op._node = Singleton::from(node.clone()); node } Op::Call(Call { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Call(ref mut call) => { let node: Link = Node::Call(back).into(); - call._node = Some(node.clone()); + call._node = Singleton::from(node.clone()); node } Op::Fold(Fold { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Fold(ref mut fold) => { let node: Link = Node::Fold(back).into(); - fold._node = Some(node.clone()); + fold._node = Singleton::from(node.clone()); node } Op::Vector(Vector { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Vector(ref mut vector) => { let node: Link = Node::Vector(back).into(); - vector._node = Some(node.clone()); + vector._node = Singleton::from(node.clone()); node } Op::Matrix(Matrix { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Matrix(ref mut matrix) => { let node: Link = Node::Matrix(back).into(); - matrix._node = Some(node.clone()); + matrix._node = Singleton::from(node.clone()); node } Op::Accessor(Accessor { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Accessor(ref mut accessor) => { let node: Link = Node::Accessor(back).into(); - accessor._node = Some(node.clone()); + accessor._node = Singleton::from(node.clone()); + node + } + Op::BusOp(BusOp { + _node: Singleton(Some(link)), + .. + }) => link.clone(), + Op::BusOp(ref mut bus_op) => { + let node: Link = Node::BusOp(back).into(); + bus_op._node = Singleton::from(node.clone()); node } Op::Parameter(Parameter { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Parameter(ref mut parameter) => { let node: Link = Node::Parameter(back).into(); - parameter._node = Some(node.clone()); + parameter._node = Singleton::from(node.clone()); node } Op::Value(Value { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Op::Value(ref mut value) => { let node: Link = Node::Value(back).into(); - value._node = Some(node.clone()); + value._node = Singleton::from(node.clone()); node } Op::None(span) => Node::None(*span).into(), @@ -403,108 +440,130 @@ impl Link { let back: BackLink = self.clone().into(); match self.clone().borrow_mut().deref_mut() { Op::Enf(Enf { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Enf(ref mut enf) => { let owner: Link = Owner::Enf(back).into(); - enf._owner = Some(owner.clone()); - enf._owner.clone() + enf._owner = Singleton::from(owner.clone()); + enf._owner.0.clone() } Op::Boundary(Boundary { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Boundary(ref mut boundary) => { let owner: Link = Owner::Boundary(back).into(); - boundary._owner = Some(owner.clone()); - boundary._owner.clone() + boundary._owner = Singleton::from(owner.clone()); + boundary._owner.0.clone() } Op::Add(Add { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Add(ref mut add) => { let owner: Link = Owner::Add(back).into(); - add._owner = Some(owner.clone()); - add._owner.clone() + add._owner = Singleton::from(owner.clone()); + add._owner.0.clone() } Op::Sub(Sub { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Sub(ref mut sub) => { let owner: Link = Owner::Sub(back).into(); - sub._owner = Some(owner.clone()); - sub._owner.clone() + sub._owner = Singleton::from(owner.clone()); + sub._owner.0.clone() } Op::Mul(Mul { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Mul(ref mut mul) => { let owner: Link = Owner::Mul(back).into(); - mul._owner = Some(owner.clone()); - mul._owner.clone() + mul._owner = Singleton::from(owner.clone()); + mul._owner.0.clone() } Op::Exp(Exp { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Exp(ref mut exp) => { let owner: Link = Owner::Exp(back).into(); - exp._owner = Some(owner.clone()); - exp._owner.clone() + exp._owner = Singleton::from(owner.clone()); + exp._owner.0.clone() } Op::If(If { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::If(ref mut if_op) => { let owner: Link = Owner::If(back).into(); - if_op._owner = Some(owner.clone()); - if_op._owner.clone() + if_op._owner = Singleton::from(owner.clone()); + if_op._owner.0.clone() } Op::For(For { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::For(ref mut for_op) => { let owner: Link = Owner::For(back).into(); - for_op._owner = Some(owner.clone()); - for_op._owner.clone() + for_op._owner = Singleton::from(owner.clone()); + for_op._owner.0.clone() } Op::Call(Call { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Call(ref mut call) => { let owner: Link = Owner::Call(back).into(); - call._owner = Some(owner.clone()); - call._owner.clone() + call._owner = Singleton::from(owner.clone()); + call._owner.0.clone() } Op::Fold(Fold { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Fold(ref mut fold) => { let owner: Link = Owner::Fold(back).into(); - fold._owner = Some(owner.clone()); - fold._owner.clone() + fold._owner = Singleton::from(owner.clone()); + fold._owner.0.clone() } Op::Vector(Vector { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Vector(ref mut vector) => { let owner: Link = Owner::Vector(back).into(); - vector._owner = Some(owner.clone()); - vector._owner.clone() + vector._owner = Singleton::from(owner.clone()); + vector._owner.0.clone() } Op::Matrix(Matrix { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Matrix(ref mut matrix) => { let owner: Link = Owner::Matrix(back).into(); - matrix._owner = Some(owner.clone()); - matrix._owner.clone() + matrix._owner = Singleton::from(owner.clone()); + matrix._owner.0.clone() } Op::Accessor(Accessor { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => Some(link.clone()), Op::Accessor(ref mut accessor) => { let owner: Link = Owner::Accessor(back).into(); - accessor._owner = Some(owner.clone()); - accessor._owner.clone() + accessor._owner = Singleton::from(owner.clone()); + accessor._owner.0.clone() + } + Op::BusOp(BusOp { + _owner: Singleton(Some(link)), + .. + }) => Some(link.clone()), + Op::BusOp(ref mut bus_op) => { + let owner: Link = Owner::BusOp(back).into(); + bus_op._owner = Singleton::from(owner.clone()); + bus_op._owner.0.clone() } Op::Parameter(_) => None, Op::Value(_) => None, @@ -719,6 +778,22 @@ impl Link { _ => None, }) } + /// Try getting the current [Op]'s inner [BusOp]. + /// Returns None if the current [Op] is not a [BusOp] or the Rc count is zero + pub fn as_bus_op(&self) -> Option> { + get_inner(self.borrow(), |op| match op { + Op::BusOp(inner) => Some(inner), + _ => None, + }) + } + /// Try getting the current [Op]'s inner [BusOp], borrowing mutably. + /// Returns None if the current [Op] is not a [BusOp] or the Rc count is zero + pub fn as_bus_op_mut(&self) -> Option> { + get_inner_mut(self.borrow_mut(), |op| match op { + Op::BusOp(inner) => Some(inner), + _ => None, + }) + } /// Try getting the current [Op]'s inner [Parameter]. /// Returns None if the current [Op] is not a [Parameter] or the Rc count is zero pub fn as_parameter(&self) -> Option> { @@ -752,3 +827,16 @@ impl Link { }) } } + +impl From for Link { + fn from(value: i64) -> Self { + Op::Value(Value { + value: SpannedMirValue { + value: MirValue::Constant(ConstantValue::Felt(value as u64)), + ..Default::default() + }, + ..Default::default() + }) + .into() + } +} diff --git a/mir/src/ir/nodes/ops/accessor.rs b/mir/src/ir/nodes/ops/accessor.rs index c5ced0b50..bf5178a78 100644 --- a/mir/src/ir/nodes/ops/accessor.rs +++ b/mir/src/ir/nodes/ops/accessor.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use air_parser::ast::AccessType; use miden_diagnostics::{SourceSpan, Spanned}; use std::hash::Hash; @@ -7,31 +7,17 @@ use std::hash::Hash; /// - access_type: AccessType, which describes for example how to access a given index for a Vector (e.g. `v[0]`) /// - offset: usize, which describes the row offset for a trace column access (e.g. `a'`) /// -#[derive(Hash, Clone, PartialEq, Eq, Debug, Builder, Spanned)] +#[derive(Hash, Clone, PartialEq, Eq, Debug, Builder, Spanned, Default)] #[enum_wrapper(Op)] pub struct Accessor { pub parents: Vec>, pub indexable: Link, pub access_type: AccessType, pub offset: usize, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, -} - -impl Default for Accessor { - fn default() -> Self { - Self { - parents: Vec::default(), - indexable: Link::default(), - access_type: AccessType::Default, - offset: 0, - _node: None, - _owner: None, - span: SourceSpan::default(), - } - } + pub span: SourceSpan, } impl Accessor { diff --git a/mir/src/ir/nodes/ops/add.rs b/mir/src/ir/nodes/ops/add.rs index 804fec34d..6affe1cde 100644 --- a/mir/src/ir/nodes/ops/add.rs +++ b/mir/src/ir/nodes/ops/add.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR operation to represent the addition of two MIR ops, `lhs` and `rhs` @@ -9,10 +9,10 @@ pub struct Add { pub parents: Vec>, pub lhs: Link, pub rhs: Link, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Add { diff --git a/mir/src/ir/nodes/ops/boundary.rs b/mir/src/ir/nodes/ops/boundary.rs index ff6cb4495..d96cf20b2 100644 --- a/mir/src/ir/nodes/ops/boundary.rs +++ b/mir/src/ir/nodes/ops/boundary.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use air_parser::ast::Boundary as BoundaryKind; use miden_diagnostics::{SourceSpan, Spanned}; use std::hash::Hash; @@ -7,29 +7,16 @@ use std::hash::Hash; /// /// Note: Boundary ops are only valid to describe boundary constraints, not integrity constraints /// -#[derive(Clone, PartialEq, Eq, Debug, Builder, Spanned)] +#[derive(Clone, PartialEq, Default, Eq, Debug, Builder, Spanned)] #[enum_wrapper(Op)] pub struct Boundary { pub parents: Vec>, pub kind: BoundaryKind, pub expr: Link, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, -} - -impl Default for Boundary { - fn default() -> Self { - Self { - parents: Vec::default(), - kind: BoundaryKind::First, - expr: Link::default(), - _node: None, - _owner: None, - span: SourceSpan::default(), - } - } + pub span: SourceSpan, } impl Hash for Boundary { diff --git a/mir/src/ir/nodes/ops/bus_op.rs b/mir/src/ir/nodes/ops/bus_op.rs new file mode 100644 index 000000000..44f98465d --- /dev/null +++ b/mir/src/ir/nodes/ops/bus_op.rs @@ -0,0 +1,82 @@ +use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Parent, Singleton}; +use miden_diagnostics::{SourceSpan, Spanned}; +use std::hash::Hash; + +#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)] +pub enum BusOpKind { + #[default] + Add, + Rem, +} + +#[derive(Default, Clone, Eq, Debug, Spanned, Builder)] +#[enum_wrapper(Op)] +pub struct BusOp { + pub parents: Vec>, + pub bus: BackLink, + pub kind: BusOpKind, + pub args: Vec>, + pub _latch: Link, + pub _node: Singleton, + pub _owner: Singleton, + #[span] + pub span: SourceSpan, +} + +impl Hash for BusOp { + fn hash(&self, state: &mut H) { + self.bus.get_name().hash(state); + self.kind.hash(state); + self.args.hash(state); + self._latch.hash(state); + } +} + +impl PartialEq for BusOp { + fn eq(&self, other: &Self) -> bool { + self.bus.get_name() == other.bus.get_name() + && self.kind == other.kind + && self.args == other.args + && self._latch == other._latch + } +} + +impl BusOp { + pub fn create( + bus: BackLink, + kind: BusOpKind, + args: Vec>, + span: SourceSpan, + ) -> Link { + Op::BusOp(Self { + bus, + kind, + args, + span, + ..Default::default() + }) + .into() + } +} + +impl Parent for BusOp { + type Child = Op; + fn children(&self) -> Link>> { + let mut children = self.args.clone(); + children.push(self._latch.clone()); + children.into() + } +} + +impl Child for BusOp { + type Parent = Owner; + fn get_parents(&self) -> Vec> { + self.parents.clone() + } + fn add_parent(&mut self, parent: Link) { + self.parents.push(parent.into()); + } + fn remove_parent(&mut self, parent: Link) { + self.parents.retain(|p| *p != parent.clone().into()); + } +} diff --git a/mir/src/ir/nodes/ops/call.rs b/mir/src/ir/nodes/ops/call.rs index 8609e74d2..854261dfa 100644 --- a/mir/src/ir/nodes/ops/call.rs +++ b/mir/src/ir/nodes/ops/call.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Root}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Root, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR operation to represent a call to a given function, a `Root` that represents either a `Function` or an `Evaluator` @@ -14,10 +14,10 @@ pub struct Call { pub function: Link, /// Parent::children only contains the arguments pub arguments: Link>>, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Call { diff --git a/mir/src/ir/nodes/ops/enf.rs b/mir/src/ir/nodes/ops/enf.rs index 4ba32b7a0..6abdeb25c 100644 --- a/mir/src/ir/nodes/ops/enf.rs +++ b/mir/src/ir/nodes/ops/enf.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR operation to enforce that a given MIR op, `expr` equals zero @@ -8,10 +8,10 @@ use miden_diagnostics::{SourceSpan, Spanned}; pub struct Enf { pub parents: Vec>, pub expr: Link, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Enf { diff --git a/mir/src/ir/nodes/ops/exp.rs b/mir/src/ir/nodes/ops/exp.rs index bf6106ae6..b317cc159 100644 --- a/mir/src/ir/nodes/ops/exp.rs +++ b/mir/src/ir/nodes/ops/exp.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR operation to represent the exponentiation of a MIR op, `lhs` by another, `rhs` @@ -11,10 +11,10 @@ pub struct Exp { pub parents: Vec>, pub lhs: Link, pub rhs: Link, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Exp { diff --git a/mir/src/ir/nodes/ops/fold.rs b/mir/src/ir/nodes/ops/fold.rs index 38c668220..03dad99a6 100644 --- a/mir/src/ir/nodes/ops/fold.rs +++ b/mir/src/ir/nodes/ops/fold.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR operation to represent folding a given Vector operator according to a given operator and initial value @@ -15,10 +15,10 @@ pub struct Fold { pub iterator: Link, pub operator: FoldOperator, pub initial_value: Link, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } #[derive(Default, Clone, PartialEq, Eq, Debug, Hash)] diff --git a/mir/src/ir/nodes/ops/for_op.rs b/mir/src/ir/nodes/ops/for_op.rs index e5945f9c0..7b67d8b90 100644 --- a/mir/src/ir/nodes/ops/for_op.rs +++ b/mir/src/ir/nodes/ops/for_op.rs @@ -1,6 +1,6 @@ use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent list comprehensions. /// @@ -17,10 +17,10 @@ pub struct For { pub iterators: Link>>, pub expr: Link, pub selector: Link, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl For { diff --git a/mir/src/ir/nodes/ops/if_op.rs b/mir/src/ir/nodes/ops/if_op.rs index cfe5627d9..02ae44da2 100644 --- a/mir/src/ir/nodes/ops/if_op.rs +++ b/mir/src/ir/nodes/ops/if_op.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR operation to represent conditional constraints @@ -15,10 +15,10 @@ pub struct If { pub condition: Link, pub then_branch: Link, pub else_branch: Link, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl If { diff --git a/mir/src/ir/nodes/ops/matrix.rs b/mir/src/ir/nodes/ops/matrix.rs index c48dd692e..ba539c5f3 100644 --- a/mir/src/ir/nodes/ops/matrix.rs +++ b/mir/src/ir/nodes/ops/matrix.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR operation to represent a matrix of MIR ops of a given size @@ -10,10 +10,10 @@ pub struct Matrix { pub size: usize, // elements are of type Vector pub elements: Link>>, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Matrix { diff --git a/mir/src/ir/nodes/ops/mod.rs b/mir/src/ir/nodes/ops/mod.rs index 8e6e70ef4..50bcd7587 100644 --- a/mir/src/ir/nodes/ops/mod.rs +++ b/mir/src/ir/nodes/ops/mod.rs @@ -1,6 +1,7 @@ mod accessor; mod add; mod boundary; +mod bus_op; mod call; mod enf; mod exp; @@ -17,6 +18,7 @@ mod vector; pub use accessor::Accessor; pub use add::Add; pub use boundary::Boundary; +pub use bus_op::{BusOp, BusOpKind}; pub use call::Call; pub use enf::Enf; pub use exp::Exp; @@ -28,7 +30,8 @@ pub use mul::Mul; pub use parameter::Parameter; pub use sub::Sub; pub use value::{ - ConstantValue, MirType, MirValue, PeriodicColumnAccess, PublicInputAccess, SpannedMirValue, - TraceAccess, TraceAccessBinding, Value, + BusAccess, ConstantValue, MirType, MirValue, PeriodicColumnAccess, + PublicInputAccess, /*, PublicInputBinding*/ + SpannedMirValue, TraceAccess, TraceAccessBinding, Value, }; pub use vector::Vector; diff --git a/mir/src/ir/nodes/ops/mul.rs b/mir/src/ir/nodes/ops/mul.rs index e4ab87a3f..05954c3af 100644 --- a/mir/src/ir/nodes/ops/mul.rs +++ b/mir/src/ir/nodes/ops/mul.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR operation to represent the multiplication of two MIR ops, `lhs` and `rhs` @@ -9,10 +9,10 @@ pub struct Mul { pub parents: Vec>, pub lhs: Link, pub rhs: Link, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Mul { diff --git a/mir/src/ir/nodes/ops/parameter.rs b/mir/src/ir/nodes/ops/parameter.rs index 15fb44825..7dbdc6b3b 100644 --- a/mir/src/ir/nodes/ops/parameter.rs +++ b/mir/src/ir/nodes/ops/parameter.rs @@ -1,5 +1,5 @@ use super::MirType; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; use std::hash::{Hash, Hasher}; @@ -15,9 +15,9 @@ pub struct Parameter { pub position: usize, /// The type of the `Parameter` pub ty: MirType, - pub _node: Option>, + pub _node: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Parameter { @@ -27,7 +27,7 @@ impl Parameter { ref_node: BackLink::none(), position, ty, - _node: None, + _node: Singleton::none(), span, }) .into() @@ -50,6 +50,11 @@ impl PartialEq for Parameter { fn eq(&self, other: &Self) -> bool { self.position == other.position && self.ty == other.ty + // TODO: This always returns true. + // fix this by inserting unique ids in the nodes in a + // linearized order in place of the hash. + // See the relationship between [crate::ir::Bus] and [crate::ir::BusOp] + // and their use in [crate::ir::Graph::insert_bus] for an example. && get_hash(&self.ref_node) == get_hash(&other.ref_node) } } @@ -58,6 +63,11 @@ impl Hash for Parameter { fn hash(&self, state: &mut H) { self.position.hash(state); self.ty.hash(state); + // TODO: This always returns true. + // fix this by inserting unique ids in the nodes in a + // linearized order in place of the hash. + // See the relationship between [crate::ir::Bus] and [crate::ir::BusOp] + // and their use in [crate::ir::Graph::insert_bus] for an example. self.ref_node.hash(state); } } diff --git a/mir/src/ir/nodes/ops/sub.rs b/mir/src/ir/nodes/ops/sub.rs index 71148b7fc..99b49a353 100644 --- a/mir/src/ir/nodes/ops/sub.rs +++ b/mir/src/ir/nodes/ops/sub.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR operation to represent the substraction of two MIR ops, `lhs` and `rhs` @@ -9,10 +9,10 @@ pub struct Sub { pub parents: Vec>, pub lhs: Link, pub rhs: Link, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Sub { diff --git a/mir/src/ir/nodes/ops/value.rs b/mir/src/ir/nodes/ops/value.rs index 25317c3a2..72c045df4 100644 --- a/mir/src/ir/nodes/ops/value.rs +++ b/mir/src/ir/nodes/ops/value.rs @@ -1,7 +1,7 @@ use air_parser::ast::{self, Identifier, QualifiedIdentifier, TraceColumnIndex, TraceSegmentId}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner}; +use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Singleton}; /// A MIR operation to represent a known value, [Value] /// Wraps a [SpannedMirValue] to represent a known value in the [MIR] @@ -11,7 +11,7 @@ pub struct Value { pub parents: Vec>, #[span] pub value: SpannedMirValue, - pub _node: Option>, + pub _node: Singleton, } impl Value { @@ -24,6 +24,18 @@ impl Value { } } +impl From for Value { + fn from(value: i64) -> Self { + Self { + value: SpannedMirValue { + value: MirValue::Constant(ConstantValue::Felt(value as u64)), + span: Default::default(), + }, + ..Default::default() + } + } +} + impl Child for Value { type Parent = Owner; fn get_parents(&self) -> Vec> { @@ -53,12 +65,39 @@ pub enum MirValue { PeriodicColumn(PeriodicColumnAccess), /// A reference to a specific element of a given public input PublicInput(PublicInputAccess), + // TODO: Will be used when handling variable-length public inputs + /*/// A reference to a given public input + PublicInputBinding(PublicInputBinding),*/ /// A reference to the `random_values` array, specifically the element at the given index RandomValue(usize), /// A binding to a set of consecutive trace columns of a given size TraceAccessBinding(TraceAccessBinding), /// A binding to a range of random values RandomValueBinding(RandomValueBinding), + /// A binding to a [Bus]. + BusAccess(BusAccess), + Null, +} + +/// [BusAccess] is like [SymbolAccess], but is used to describe an access to a specific bus. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct BusAccess { + /// The trace segment being accessed + pub bus: Link, + /// The offset from the current row. + /// + /// Defaults to 0, which indicates no offset/the current row. + /// + /// For example, if accessing a trace column with `a'`, where `a` is bound to a single column, + /// the row offset would be `1`, as the `'` modifier indicates the "next" row. + pub row_offset: usize, +} + +impl BusAccess { + /// Creates a new [BusAccess]. + pub const fn new(bus: Link, row_offset: usize) -> Self { + Self { bus, row_offset } + } } #[derive(Debug, Eq, PartialEq, Clone, Hash)] @@ -164,6 +203,20 @@ impl PublicInputAccess { } } +// TODO: Will be used when handling variable-length public inputs +/*/// Represents an access of a [PublicInput], similar in nature to [TraceAccess] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub struct PublicInputBinding { + /// The name of the public input to bind + pub name: Identifier, +} + +impl PublicInputBinding { + pub const fn new(name: Identifier) -> Self { + Self { name } + } +}*/ + impl Default for SpannedMirValue { fn default() -> Self { Self { diff --git a/mir/src/ir/nodes/ops/vector.rs b/mir/src/ir/nodes/ops/vector.rs index bbd2f7366..7bccb34fa 100644 --- a/mir/src/ir/nodes/ops/vector.rs +++ b/mir/src/ir/nodes/ops/vector.rs @@ -1,4 +1,4 @@ -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent}; +use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR operation to represent a vector of MIR ops of a given size @@ -9,10 +9,10 @@ pub struct Vector { pub parents: Vec>, pub size: usize, pub elements: Link>>, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Vector { diff --git a/mir/src/ir/nodes/root.rs b/mir/src/ir/nodes/root.rs index 9e2f3c64e..71ab6a54e 100644 --- a/mir/src/ir/nodes/root.rs +++ b/mir/src/ir/nodes/root.rs @@ -7,6 +7,7 @@ use miden_diagnostics::{SourceSpan, Spanned}; use crate::ir::{ get_inner, get_inner_mut, BackLink, Evaluator, Function, Link, Node, Op, Owner, Parent, + Singleton, }; /// The root nodes of the MIR Graph @@ -58,19 +59,21 @@ impl Link { let back: BackLink = self.clone().into(); match self.borrow_mut().deref_mut() { Root::Function(Function { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Root::Function(ref mut f) => { let node: Link = Node::Function(back).into(); - f._node = Some(node.clone()); + f._node = Singleton::from(node.clone()); node } Root::Evaluator(Evaluator { - _node: Some(link), .. + _node: Singleton(Some(link)), + .. }) => link.clone(), Root::Evaluator(ref mut e) => { let node: Link = Node::Evaluator(back).into(); - e._node = Some(node.clone()); + e._node = Singleton::from(node.clone()); node } Root::None(span) => Node::None(*span).into(), @@ -82,19 +85,21 @@ impl Link { let back: BackLink = self.clone().into(); match self.borrow_mut().deref_mut() { Root::Function(Function { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => link.clone(), Root::Function(ref mut f) => { let owner: Link = Owner::Function(back).into(); - f._owner = Some(owner.clone()); + f._owner = Singleton::from(owner.clone()); owner } Root::Evaluator(Evaluator { - _owner: Some(link), .. + _owner: Singleton(Some(link)), + .. }) => link.clone(), Root::Evaluator(ref mut e) => { let owner: Link = Owner::Evaluator(back).into(); - e._owner = Some(owner.clone()); + e._owner = Singleton::from(owner.clone()); owner } Root::None(span) => Owner::None(*span).into(), diff --git a/mir/src/ir/nodes/roots/evaluator.rs b/mir/src/ir/nodes/roots/evaluator.rs index 295803361..49df90cc9 100644 --- a/mir/src/ir/nodes/roots/evaluator.rs +++ b/mir/src/ir/nodes/roots/evaluator.rs @@ -1,7 +1,6 @@ +use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root}; - /// A MIR Root to represent a Evaluator definition #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] #[enum_wrapper(Root)] @@ -11,10 +10,10 @@ pub struct Evaluator { pub parameters: Vec>>, // Operations contained in the Evaluator pub body: Link>>, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Evaluator { diff --git a/mir/src/ir/nodes/roots/function.rs b/mir/src/ir/nodes/roots/function.rs index 040f26a36..9fa3fa150 100644 --- a/mir/src/ir/nodes/roots/function.rs +++ b/mir/src/ir/nodes/roots/function.rs @@ -1,4 +1,4 @@ -use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root}; +use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root, Singleton}; use miden_diagnostics::{SourceSpan, Spanned}; /// A MIR Root to represent a Function definition @@ -11,10 +11,10 @@ pub struct Function { pub return_type: Link, // Operations contained in the function pub body: Link>>, - pub _node: Option>, - pub _owner: Option>, + pub _node: Singleton, + pub _owner: Singleton, #[span] - span: SourceSpan, + pub span: SourceSpan, } impl Function { diff --git a/mir/src/ir/owner.rs b/mir/src/ir/owner.rs index d3736bd00..77edc614c 100644 --- a/mir/src/ir/owner.rs +++ b/mir/src/ir/owner.rs @@ -17,6 +17,7 @@ pub enum Owner { Function(BackLink), Evaluator(BackLink), Accessor(BackLink), + BusOp(BackLink), Boundary(BackLink), Vector(BackLink), Matrix(BackLink), @@ -51,6 +52,7 @@ impl Parent for Owner { Owner::Vector(v) => v.children(), Owner::Matrix(m) => m.children(), Owner::Accessor(a) => a.children(), + Owner::BusOp(b) => b.children(), Owner::None(_) => Link::default(), } } @@ -75,6 +77,7 @@ impl Child for Owner { Owner::Vector(v) => v.get_parents(), Owner::Matrix(m) => m.get_parents(), Owner::Accessor(a) => a.get_parents(), + Owner::BusOp(b) => b.get_parents(), Owner::None(_) => Vec::default(), } } @@ -95,6 +98,7 @@ impl Child for Owner { Owner::Vector(v) => v.add_parent(parent), Owner::Matrix(m) => m.add_parent(parent), Owner::Accessor(a) => a.add_parent(parent), + Owner::BusOp(b) => b.add_parent(parent), Owner::None(_) => (), } } @@ -115,6 +119,7 @@ impl Child for Owner { Owner::Vector(v) => v.remove_parent(parent), Owner::Matrix(m) => m.remove_parent(parent), Owner::Accessor(a) => a.remove_parent(parent), + Owner::BusOp(b) => b.remove_parent(parent), Owner::None(_) => (), } } @@ -139,6 +144,7 @@ impl PartialEq for Owner { (Owner::Vector(lhs), Owner::Vector(rhs)) => lhs.to_link() == rhs.to_link(), (Owner::Matrix(lhs), Owner::Matrix(rhs)) => lhs.to_link() == rhs.to_link(), (Owner::Accessor(lhs), Owner::Accessor(rhs)) => lhs.to_link() == rhs.to_link(), + (Owner::BusOp(lhs), Owner::BusOp(rhs)) => lhs.to_link() == rhs.to_link(), (Owner::None(_), Owner::None(_)) => true, _ => false, } @@ -164,6 +170,7 @@ impl std::hash::Hash for Owner { Owner::Vector(v) => v.to_link().hash(state), Owner::Matrix(m) => m.to_link().hash(state), Owner::Accessor(a) => a.to_link().hash(state), + Owner::BusOp(b) => b.to_link().hash(state), Owner::None(s) => s.hash(state), } } @@ -189,6 +196,7 @@ impl Link { Op::Vector(_) => Owner::Vector(BackLink::from(op_inner_val)), Op::Matrix(_) => Owner::Matrix(BackLink::from(op_inner_val)), Op::Accessor(_) => Owner::Accessor(BackLink::from(op_inner_val)), + Op::BusOp(_) => Owner::BusOp(BackLink::from(op_inner_val)), Op::Parameter(_) => unreachable!(), Op::Value(_) => unreachable!(), Op::None(span) => Owner::None(*span), @@ -221,6 +229,7 @@ impl Link { Owner::Function(f) => f.to_link(), Owner::Evaluator(e) => e.to_link(), Owner::Accessor(_) => None, + Owner::BusOp(_) => None, Owner::Boundary(_) => None, Owner::Vector(_) => None, Owner::Matrix(_) => None, @@ -243,6 +252,7 @@ impl Link { Owner::Function(_) => None, Owner::Evaluator(_) => None, Owner::Accessor(back) => back.to_link(), + Owner::BusOp(back) => back.to_link(), Owner::Boundary(back) => back.to_link(), Owner::Vector(back) => back.to_link(), Owner::Matrix(back) => back.to_link(), @@ -277,6 +287,7 @@ impl BackLink { Owner::Function(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::Evaluator(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::Accessor(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), + Owner::BusOp(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::Boundary(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::Vector(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::Matrix(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), diff --git a/mir/src/ir/utils.rs b/mir/src/ir/utils.rs new file mode 100644 index 000000000..48ac201f9 --- /dev/null +++ b/mir/src/ir/utils.rs @@ -0,0 +1,342 @@ +use std::{ + collections::BTreeMap, + hash::{DefaultHasher, Hash, Hasher}, +}; + +use air_parser::ast::{Identifier, NamespacedIdentifier}; +use miden_diagnostics::{SourceSpan, Span}; +use pretty_assertions::assert_eq; + +use crate::{ir::*, passes::Visitor, CompileError}; + +pub fn strip_spans(mir: &mut Mir) { + let graph = mir.constraint_graph_mut(); + let mut visitor = StripSpansVisitor::default(); + match visitor.run(graph) { + Ok(_) => {} + Err(e) => { + panic!("Error stripping spans: {:?}", e); + } + } +} + +#[derive(Default)] +pub struct StripSpansVisitor { + _done: BTreeMap, + work_stack: Vec>, +} + +pub fn extract_roots( + graph: &Graph, + include_boundary: bool, + include_integrity: bool, + include_bus: bool, + include_func: bool, + include_eval: bool, +) -> Vec> { + let mut nodes = Vec::new(); + if include_boundary { + let boundary_ref = graph.boundary_constraints_roots.borrow(); + let boundary = boundary_ref.iter().map(|n| n.as_node()); + nodes.extend(boundary); + } + if include_integrity { + let integrity_ref = graph.integrity_constraints_roots.borrow(); + let integrity = integrity_ref.iter().map(|n| n.as_node()); + nodes.extend(integrity); + } + if include_bus { + let buses = graph.get_bus_nodes(); + let bus_columns = buses + .iter() + .flat_map(|b| b.borrow().columns.clone()) + .map(|n| n.as_node()); + let bus_latches = buses + .iter() + .flat_map(|b| b.borrow().latches.clone()) + .map(|n| n.as_node()); + nodes.extend(bus_columns); + nodes.extend(bus_latches); + } + if include_func { + let funcs = graph.get_function_nodes().into_iter().map(|n| n.as_node()); + nodes.extend(funcs); + } + if include_eval { + let evals = graph.get_evaluator_nodes().into_iter().map(|n| n.as_node()); + nodes.extend(evals); + } + nodes +} + +pub fn extract_all_roots(graph: &Graph) -> Vec> { + extract_roots(graph, true, true, true, true, true) +} + +pub fn extract_boundary_roots(graph: &Graph) -> Vec> { + extract_roots(graph, true, false, false, false, false) +} + +pub fn extract_integrity_roots(graph: &Graph) -> Vec> { + extract_roots(graph, false, true, false, false, false) +} + +pub fn extract_bus_roots(graph: &Graph) -> Vec> { + extract_roots(graph, false, false, true, false, false) +} + +pub fn extract_function_roots(graph: &Graph) -> Vec> { + extract_roots(graph, false, false, false, true, false) +} + +pub fn extract_evaluator_roots(graph: &Graph) -> Vec> { + extract_roots(graph, false, false, false, false, true) +} + +impl Visitor for StripSpansVisitor { + fn work_stack(&mut self) -> &mut Vec> { + &mut self.work_stack + } + fn root_nodes_to_visit(&self, graph: &Graph) -> Vec> { + extract_all_roots(graph) + } + + fn visit_function( + &mut self, + _graph: &mut Graph, + function: Link, + ) -> Result<(), CompileError> { + let mut function = function.as_function_mut().unwrap(); + function.span = Default::default(); + Ok(()) + } + + fn visit_evaluator( + &mut self, + _graph: &mut Graph, + evaluator: Link, + ) -> Result<(), CompileError> { + let mut evaluator = evaluator.as_evaluator_mut().unwrap(); + evaluator.span = Default::default(); + Ok(()) + } + + fn visit_enf(&mut self, _graph: &mut Graph, enf: Link) -> Result<(), CompileError> { + let mut enf = enf.as_enf_mut().unwrap(); + enf.span = Default::default(); + Ok(()) + } + + fn visit_boundary( + &mut self, + _graph: &mut Graph, + boundary: Link, + ) -> Result<(), CompileError> { + let mut boundary = boundary.as_boundary_mut().unwrap(); + boundary.span = Default::default(); + Ok(()) + } + + fn visit_add(&mut self, _graph: &mut Graph, add: Link) -> Result<(), CompileError> { + let mut add = add.as_add_mut().unwrap(); + add.span = Default::default(); + Ok(()) + } + + fn visit_sub(&mut self, _graph: &mut Graph, sub: Link) -> Result<(), CompileError> { + let mut sub = sub.as_sub_mut().unwrap(); + sub.span = Default::default(); + Ok(()) + } + + fn visit_mul(&mut self, _graph: &mut Graph, mul: Link) -> Result<(), CompileError> { + let mut mul = mul.as_mul_mut().unwrap(); + mul.span = Default::default(); + Ok(()) + } + + fn visit_exp(&mut self, _graph: &mut Graph, exp: Link) -> Result<(), CompileError> { + let mut exp = exp.as_exp_mut().unwrap(); + exp.span = Default::default(); + Ok(()) + } + + fn visit_if(&mut self, _graph: &mut Graph, if_node: Link) -> Result<(), CompileError> { + let mut if_node = if_node.as_if_mut().unwrap(); + if_node.span = Default::default(); + Ok(()) + } + + fn visit_for(&mut self, _graph: &mut Graph, for_node: Link) -> Result<(), CompileError> { + let mut for_node = for_node.as_for_mut().unwrap(); + for_node.span = Default::default(); + Ok(()) + } + + fn visit_call(&mut self, _graph: &mut Graph, call: Link) -> Result<(), CompileError> { + let mut call = call.as_call_mut().unwrap(); + call.span = Default::default(); + Ok(()) + } + + fn visit_fold(&mut self, _graph: &mut Graph, fold: Link) -> Result<(), CompileError> { + let mut fold = fold.as_fold_mut().unwrap(); + fold.span = Default::default(); + Ok(()) + } + + fn visit_vector(&mut self, _graph: &mut Graph, vector: Link) -> Result<(), CompileError> { + let mut vector = vector.as_vector_mut().unwrap(); + vector.span = Default::default(); + Ok(()) + } + + fn visit_matrix(&mut self, _graph: &mut Graph, matrix: Link) -> Result<(), CompileError> { + let mut matrix = matrix.as_matrix_mut().unwrap(); + matrix.span = Default::default(); + Ok(()) + } + + fn visit_accessor( + &mut self, + _graph: &mut Graph, + accessor: Link, + ) -> Result<(), CompileError> { + let mut accessor = accessor.as_accessor_mut().unwrap(); + accessor.span = Default::default(); + Ok(()) + } + + fn visit_bus_op(&mut self, _graph: &mut Graph, bus_op: Link) -> Result<(), CompileError> { + let mut bus_op = bus_op.as_bus_op_mut().unwrap(); + bus_op.span = Default::default(); + Ok(()) + } + + fn visit_parameter( + &mut self, + _graph: &mut Graph, + parameter: Link, + ) -> Result<(), CompileError> { + let mut parameter = parameter.as_parameter_mut().unwrap(); + parameter.span = Default::default(); + Ok(()) + } + + fn visit_value(&mut self, _graph: &mut Graph, value: Link) -> Result<(), CompileError> { + let mut value = value.as_value_mut().unwrap(); + value.value.span = Default::default(); + match &mut value.value.value { + MirValue::Constant(_) => {} + MirValue::TraceAccess(_) => {} + MirValue::PeriodicColumn(v) => { + v.name.module.0 = Span::new(SourceSpan::default(), v.name.module.0.item); + match v.name.item { + NamespacedIdentifier::Function(f) => { + v.name.item = NamespacedIdentifier::Function(Identifier::new( + SourceSpan::default(), + f.0.item, + )); + } + NamespacedIdentifier::Binding(b) => { + v.name.item = NamespacedIdentifier::Binding(Identifier::new( + SourceSpan::default(), + b.0.item, + )); + } + }; + } + MirValue::PublicInput(v) => { + v.name.0 = Span::new(SourceSpan::default(), v.name.0.item); + } + // TODO: Will be used when handling variable-length public inputs + /*MirValue::PublicInputBinding(v) => { + v.name.0 = Span::new(SourceSpan::default(), v.name.0.item); + }*/ + MirValue::RandomValue(_) => {} + MirValue::TraceAccessBinding(_) => {} + MirValue::RandomValueBinding(_) => {} + MirValue::BusAccess(_) => {} + MirValue::Null => {} + } + Ok(()) + } +} + +pub fn hash(val: &T) -> u64 { + let mut hasher = DefaultHasher::new(); + val.hash(&mut hasher); + hasher.finish() +} + +fn extract_and_compare_mir( + lhs: &mut Mir, + rhs: &mut Mir, + extract: impl Fn(&Graph) -> Vec>, +) -> bool { + strip_spans(lhs); + strip_spans(rhs); + let lhs = extract(lhs.constraint_graph()); + let rhs = extract(rhs.constraint_graph()); + hash(&lhs) == hash(&rhs) +} + +pub fn compare_mir(lhs: &mut Mir, rhs: &mut Mir) -> bool { + extract_and_compare_mir(lhs, rhs, extract_all_roots) +} + +pub fn compare_boundary(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_compare_mir(lhs, rhs, extract_boundary_roots); +} + +pub fn compare_integrity(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_compare_mir(lhs, rhs, extract_integrity_roots); +} + +pub fn compare_bus(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_compare_mir(lhs, rhs, extract_bus_roots); +} + +pub fn compare_function(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_compare_mir(lhs, rhs, extract_function_roots); +} + +pub fn compare_evaluator(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_compare_mir(lhs, rhs, extract_evaluator_roots); +} + +fn extract_and_assert_mir_eq( + lhs: &mut Mir, + rhs: &mut Mir, + extract: impl Fn(&Graph) -> Vec>, +) { + strip_spans(lhs); + strip_spans(rhs); + let lhs = extract(lhs.constraint_graph()); + let rhs = extract(rhs.constraint_graph()); + assert_eq!(lhs, rhs); +} + +pub fn assert_mir_eq(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_assert_mir_eq(lhs, rhs, extract_all_roots); +} + +pub fn assert_boundary_eq(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_assert_mir_eq(lhs, rhs, extract_boundary_roots); +} + +pub fn assert_integrity_eq(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_assert_mir_eq(lhs, rhs, extract_integrity_roots); +} + +pub fn assert_bus_eq(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_assert_mir_eq(lhs, rhs, extract_bus_roots); +} + +pub fn assert_function_eq(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_assert_mir_eq(lhs, rhs, extract_function_roots); +} + +pub fn assert_evaluator_eq(lhs: &mut Mir, rhs: &mut Mir) { + extract_and_assert_mir_eq(lhs, rhs, extract_evaluator_roots); +} diff --git a/mir/src/passes/bus_op_expand.rs b/mir/src/passes/bus_op_expand.rs new file mode 100644 index 000000000..dee879f18 --- /dev/null +++ b/mir/src/passes/bus_op_expand.rs @@ -0,0 +1,427 @@ +use std::ops::Deref; + +use air_parser::ast::{AccessType, BusType}; +use air_pass::Pass; +use miden_diagnostics::{DiagnosticsHandler, SourceSpan, Spanned}; + +use super::duplicate_node; +use crate::{ + ir::{ + Accessor, Add, BusAccess, BusOpKind, ConstantValue, Enf, Link, Mir, MirValue, Mul, Op, + SpannedMirValue, Sub, Value, + }, + CompileError, +}; + +/// TODO MIR: +/// If needed, implement bus operation expand pass on MIR +/// See https://github.com/0xPolygonMiden/air-script/issues/183 +/// +pub struct BusOpExpand<'a> { + #[allow(unused)] + diagnostics: &'a DiagnosticsHandler, +} + +impl Pass for BusOpExpand<'_> { + type Input<'a> = Mir; + type Output<'a> = Mir; + type Error = CompileError; + + fn run<'a>(&mut self, mut ir: Self::Input<'a>) -> Result, Self::Error> { + let mut max_num_random_values = ir.num_random_values as usize; + // TODO: When removing aux and rand values, use the following instead + + /*if ir.num_random_values != 0 { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("No random values should be set at this point") + .emit(); + return Err(CompileError::Failed); + }; + let mut max_num_random_values = 0;*/ + + let graph = ir.constraint_graph_mut(); + + let buses = graph.buses.clone(); + + for (_ident, bus) in buses { + let bus_type = bus.borrow().bus_type.clone(); + let columns = bus.borrow().columns.clone(); // columns are the bus_operations (add or remove of a Vec of arguments) + let latches = bus.borrow().latches.clone(); // latches are the selectors + let first = bus.borrow().get_first().clone(); + let last = bus.borrow().get_last().clone(); + //println!("first: {:?}", &first); + //println!("last: {:?}", &last); + + let bus_access = Value::create(SpannedMirValue { + span: bus.borrow().span(), + value: MirValue::BusAccess(BusAccess { + bus: bus.clone(), + row_offset: 0, + }), + }); + let bus_access_with_offset = Accessor::create( + duplicate_node(bus_access.clone(), &mut Default::default()), + AccessType::Default, + 1, + bus.borrow().span(), + ); + + // Expand bus boundary constraints first + self.handle_boundary_constraint(bus_type.clone(), first/*, air_parser::ast::Boundary::First, bus_access.clone(), bus.borrow().span()*/); + self.handle_boundary_constraint( + bus_type.clone(), + last, /*, air_parser::ast::Boundary::Last, bus_access.clone(), bus.borrow().span()*/ + ); + + // Then, expend bus integrity constraints + match bus_type { + BusType::Unit => { + // Example: + // p.add(a, b) when s + // p.rem(c, d) when (1 - s) + // => p' * (( A0 + A1 c + A2 d ) ( 1 - s ) + s) = p * ( A0 + A1 a + A2 b ) s + 1 - s + + // p' * ( columns removed combined with alphas ) = p * ( columns added combined with alphas ) + + let mut p_factor = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: crate::ir::MirValue::Constant(crate::ir::ConstantValue::Felt(1)), + }); + let mut p_prime_factor = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::Constant(crate::ir::ConstantValue::Felt(1)), + }); + + for (column, latch) in columns.iter().zip(latches.iter()) { + let bus_op = column.as_bus_op().unwrap(); + let bus_op_kind = bus_op.kind.clone(); + let bus_op_args = bus_op.args.clone(); + + // 1. Combine args with alphas + // 1.1 Start with the first alpha + let mut args_combined = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::RandomValue(0), + }); + max_num_random_values = max_num_random_values.max(1); + for (index, arg) in bus_op_args.iter().enumerate() { + // 1.2 Create corresponding alpha + let alpha = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::RandomValue(index + 1), + }); + max_num_random_values = max_num_random_values.max(index + 1); + + // 1.3 Multiply arg with alpha + let arg_times_alpha = + Mul::create(arg.clone(), alpha, SourceSpan::default()); + + // 1.4 Combine with other args + args_combined = + Add::create(args_combined, arg_times_alpha, SourceSpan::default()); + } + + // 2. Multiply by latch + let args_combined_with_latch = + Mul::create(args_combined, latch.clone(), SourceSpan::default()); + + // 3. add inverse of latch + let args_combined_with_latch_and_latch_inverse = Add::create( + args_combined_with_latch, + Sub::create( + Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::Constant(crate::ir::ConstantValue::Felt(1)), + }), + duplicate_node(latch.clone(), &mut Default::default()), + SourceSpan::default(), + ), + SourceSpan::default(), + ); + + // 4. add to p_factor or p_prime_factor (depending on bus_op_kind: add: p, rem: p_prime) + match bus_op_kind { + BusOpKind::Add => { + p_factor = Add::create( + p_factor, + args_combined_with_latch_and_latch_inverse, + SourceSpan::default(), + ); + } + BusOpKind::Rem => { + p_prime_factor = Add::create( + p_prime_factor, + args_combined_with_latch_and_latch_inverse, + SourceSpan::default(), + ); + } + } + } + + // 5. Multiply the factors with the bus column (with and without offset for p' and p respectively) + let p_prod = Mul::create(p_factor, bus_access, SourceSpan::default()); + let p_prime_prod = Mul::create( + p_prime_factor, + bus_access_with_offset, + SourceSpan::default(), + ); + + // 6. Create the resulting constraint and insert it into the graph + let resulting_constraint = Enf::create( + Sub::create(p_prod, p_prime_prod, SourceSpan::default()), + SourceSpan::default(), + ); + + graph.insert_integrity_constraints_root(resulting_constraint); + } + BusType::Mult => { + // Example: + // q.add(a, b, c) for d + // q.rem(e, f, g) when s + // => q' + s / ( A0 + A1 e + A2 f + A3 g ) = q + d / ( A0 + A1 a + A2 b + A3 c ) + + // q' + s / ( columns removed combined with alphas ) = q + d / ( columns added combined with alphas ) + // PROD * q' + s * ( columns added combined with alphas ) = PROD * q + d * ( columns removed combined with alphas ) + + // 1. Compute all the factors + let mut factors = vec![]; + for column in columns.iter() { + let bus_op = column.as_bus_op().unwrap(); + let bus_op_args = bus_op.args.clone(); + + // 1. Combine args with alphas + // 1.1 Start with the first alpha + let mut args_combined = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::RandomValue(0), + }); + max_num_random_values = max_num_random_values.max(1); + for (index, arg) in bus_op_args.iter().enumerate() { + // 1.2 Create corresponding alpha + let alpha = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::RandomValue(index + 1), + }); + max_num_random_values = max_num_random_values.max(index + 1); + + // 1.3 Multiply arg with alpha + let arg_times_alpha = + Mul::create(arg.clone(), alpha, SourceSpan::default()); + + // 1.4 Combine with other args + args_combined = + Add::create(args_combined, arg_times_alpha, SourceSpan::default()); + } + + factors.push(args_combined); + } + + // 2. Compute the product of all factors (will be used to mult q and q') + let mut total_factors = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::Constant(crate::ir::ConstantValue::Felt(1)), + }); + for factor in factors.iter() { + total_factors = + Mul::create(total_factors, factor.clone(), SourceSpan::default()); + } + + // 3. For each column, compute the product of all factors except the one of the current column, and multiply it with the latch + let mut terms_added_to_bus = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::Constant(crate::ir::ConstantValue::Felt(0)), + }); + let mut terms_removed_from_bus = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::Constant(crate::ir::ConstantValue::Felt(0)), + }); + for (index, (column, latch)) in columns.iter().zip(latches.iter()).enumerate() { + let bus_op = column.as_bus_op().unwrap(); + let bus_op_kind = bus_op.kind.clone(); + + // 3.1 Compute the product of all factors except the one of the current column + let mut factors_without_current = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::Constant(crate::ir::ConstantValue::Felt(1)), + }); + for (i, factor) in factors.iter().enumerate() { + if i != index { + factors_without_current = Mul::create( + factors_without_current, + factor.clone(), + SourceSpan::default(), + ); + } + } + + // 3.2 Multiply by latch + let factors_without_current_with_latch = Mul::create( + factors_without_current, + latch.clone(), + SourceSpan::default(), + ); + + // 3.3 Depending on the bus_op_kind, add to q_factor or q_prime_factor + match bus_op_kind { + BusOpKind::Add => { + terms_added_to_bus = Add::create( + terms_added_to_bus, + factors_without_current_with_latch, + SourceSpan::default(), + ); + } + BusOpKind::Rem => { + terms_removed_from_bus = Add::create( + terms_removed_from_bus, + factors_without_current_with_latch, + SourceSpan::default(), + ); + } + } + } + + // 4. Add all the terms together + let q_prod = + Mul::create(total_factors.clone(), bus_access, SourceSpan::default()); + let q_prime_prod = Mul::create( + total_factors.clone(), + bus_access_with_offset, + SourceSpan::default(), + ); + let q_term = + Add::create(q_prod, terms_added_to_bus.clone(), SourceSpan::default()); + let q_prime_term = Add::create( + q_prime_prod, + terms_removed_from_bus.clone(), + SourceSpan::default(), + ); + + // 5. Create the resulting constraint and insert it into the graph + let resulting_constraint = Enf::create( + Sub::create(q_term, q_prime_term, SourceSpan::default()), + SourceSpan::default(), + ); + + graph.insert_integrity_constraints_root(resulting_constraint); + } + } + } + + ir.num_random_values = max_num_random_values as u16; + + Ok(ir) + } +} + +impl<'a> BusOpExpand<'a> { + #[allow(unused)] + pub fn new(diagnostics: &'a DiagnosticsHandler) -> Self { + Self { diagnostics } + } + + fn handle_boundary_constraint( + &self, + bus_type: BusType, + link: Link, /*, boundary: air_parser::ast::Boundary, bus_access: Link, bus_span: SourceSpan*/ + ) { + let mut to_update = None; + + match link.borrow().deref() { + Op::Value(value) => { + match value.value.value { + // TODO: Will be used when handling variable-length public inputs + /*MirValue::PublicInputBinding(public_input_binding) => { + + },*/ + MirValue::Null => { + // Empty bus + + let unit_constant = match bus_type { + BusType::Unit => 1, // Product, unit for product is 1 + BusType::Mult => 0, // Sum of inverses, unit for sum is 0 + }; + let unit_val = Value::create(SpannedMirValue { + span: SourceSpan::default(), + value: MirValue::Constant(ConstantValue::Felt(unit_constant)), + }); + + to_update = Some(unit_val); + + /*let bus_boundary = Boundary::create( + duplicate_node(bus_access.clone(), &mut Default::default()), + boundary, + bus_span, + ); + + let resulting_constraint = Enf::create( + Sub::create(bus_boundary, unit_val, SourceSpan::default()), + SourceSpan::default(), + ); + + //graph.insert_boundary_constraints_root(resulting_constraint);*/ + } + _ => unreachable!(), + } + } + Op::None(_) => {} + _ => unreachable!(), + } + + if let Some(to_update) = to_update { + link.set(&to_update); + } + } +} + +/*impl Visitor for BusOpExpand<'_> { + fn work_stack(&mut self) -> &mut Vec> { + &mut self.work_stack + } + fn root_nodes_to_visit( + &self, + graph: &crate::ir::Graph, + ) -> Vec> { + let boundary_constraints_roots_ref = graph.boundary_constraints_roots.borrow(); + let integrity_constraints_roots_ref = graph.integrity_constraints_roots.borrow(); + let combined_roots = boundary_constraints_roots_ref + .clone() + .into_iter() + .map(|bc| bc.as_node()) + .chain( + integrity_constraints_roots_ref + .clone() + .into_iter() + .map(|ic| ic.as_node()), + ); + combined_roots.collect() + } + + fn visit_node(&mut self, graph: &mut Graph, node: Link) -> Result<(), CompileError> { + let updated_op: Result>, CompileError> = match node.borrow().deref() { + Node::BusOp(bus_op) => { + let bus_op_link: Link = bus_op.clone().into(); + let mut updated_node = None; + + { + // safe to unwrap because we just dispatched on it + let bus_op_ref = bus_op_link.as_bus_op().unwrap(); + let bus = bus_op_ref.bus.clone(); + let bus_kind = bus.borrow().bus_type.clone(); + let bus_operator = bus_op_ref.kind.clone(); + let args = bus_op_ref.args.clone(); + } + + Ok(updated_node) + } + _ => Ok(None), + }; + + // We update the node if needed + if let Some(updated_op) = updated_op? { + node.as_op().unwrap().set(&updated_op); + } + + Ok(()) + } +}*/ diff --git a/mir/src/passes/inlining.rs b/mir/src/passes/inlining.rs index f3addad2a..0d5dfab85 100644 --- a/mir/src/passes/inlining.rs +++ b/mir/src/passes/inlining.rs @@ -204,6 +204,11 @@ impl Visitor for InliningFirstPass<'_> { let evaluators = graph.get_evaluator_nodes(); let boundary_constraints_roots_ref = graph.boundary_constraints_roots.borrow(); let integrity_constraints_roots_ref = graph.integrity_constraints_roots.borrow(); + let bus_roots: Vec<_> = graph + .buses + .values() + .flat_map(|b| b.borrow().clone().columns.into_iter().collect::>()) + .collect(); let combined_roots = boundary_constraints_roots_ref .clone() @@ -215,6 +220,7 @@ impl Visitor for InliningFirstPass<'_> { .into_iter() .map(|ic| ic.as_node()), ) + .chain(bus_roots.into_iter().map(|b| b.as_node())) .chain(evaluators.into_iter().map(|e| e.as_node())) .chain(functions.into_iter().map(|f| f.as_node())); combined_roots.collect() @@ -333,7 +339,6 @@ impl Visitor for InliningSecondPass<'_> { .unwrap() .clone(); updated_op = Some(new_node); - //println!("Updating call node of function: {:?}", updated_op); } else { // We have finished inlining the body, we can now replace the Call node with all the body let mut new_nodes = Vec::new(); @@ -359,7 +364,6 @@ impl Visitor for InliningSecondPass<'_> { let new_nodes_vector = Vector::create(new_nodes, span); updated_op = Some(new_nodes_vector); - //println!("Updating call node of evaluator: {:?}", updated_op); } // Reset context to None diff --git a/mir/src/passes/mod.rs b/mir/src/passes/mod.rs index 62508fed8..89715deee 100644 --- a/mir/src/passes/mod.rs +++ b/mir/src/passes/mod.rs @@ -1,10 +1,13 @@ +mod bus_op_expand; mod inlining; mod translate; mod unrolling; mod visitor; +pub use bus_op_expand::BusOpExpand; pub use inlining::Inlining; pub use translate::AstToMir; pub use unrolling::Unrolling; +pub use visitor::Visitor; // Note: ConstantPropagation and ValueNumbering are not implemented yet in the MIR //mod constant_propagation; //mod value_numbering; @@ -17,8 +20,8 @@ use std::ops::Deref; use miden_diagnostics::Spanned; use crate::ir::{ - Accessor, Add, Boundary, Call, Enf, Exp, Fold, For, If, Link, Matrix, Mul, Node, Op, Owner, - Parameter, Parent, Sub, Value, Vector, + Accessor, Add, Boundary, BusOp, Call, Enf, Exp, Fold, For, If, Link, Matrix, Mul, Node, Op, + Owner, Parameter, Parent, Sub, Value, Vector, }; /// Helper to duplicate a MIR node and its children recursively @@ -173,6 +176,16 @@ pub fn duplicate_node( let new_indexable = duplicate_node(indexable, current_replace_map); Accessor::create(new_indexable, access_type, offset, accessor.span()) } + Op::BusOp(bus_op) => { + let bus = bus_op.bus.clone(); + let kind = bus_op.kind.clone(); + let args: Vec> = bus_op + .args + .iter() + .map(|x| duplicate_node(x.clone(), current_replace_map)) + .collect(); + BusOp::create(bus, kind, args, bus_op.span()) + } Op::Parameter(parameter) => { let owner_ref = parameter .ref_node @@ -415,6 +428,13 @@ pub fn duplicate_node_or_replace( let new_node = Accessor::create(new_indexable, access_type, offset, accessor.span()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); } + Op::BusOp(bus_op) => { + let bus = bus_op.bus.clone(); + let kind = bus_op.kind.clone(); + let args = bus_op.args.clone(); + let new_node = BusOp::create(bus, kind, args, bus_op.span()); + current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); + } Op::Parameter(parameter) => { let owner_ref = parameter .ref_node diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 06d5ba17d..61a22c382 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -6,11 +6,13 @@ use air_parser::{ast, symbols, LexicalScope}; use air_pass::Pass; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; -use crate::ir::{Accessor, Add, Boundary, Enf, Evaluator, Exp, Matrix, Mul, Owner, Root, Sub}; +use crate::ir::BusAccess; +//use crate::ir::PublicInputBinding; use crate::{ ir::{ - Builder, Call, ConstantValue, Fold, FoldOperator, For, Function, Link, Mir, MirType, - MirValue, Op, Parameter, PublicInputAccess, SpannedMirValue, TraceAccess, + Accessor, Add, Boundary, Builder, Bus, BusOp, BusOpKind, Call, ConstantValue, Enf, + Evaluator, Exp, Fold, FoldOperator, For, Function, Link, Matrix, Mir, MirType, MirValue, + Mul, Op, Owner, Parameter, PublicInputAccess, Root, SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Value, Vector, }, passes::duplicate_node, @@ -84,11 +86,18 @@ impl<'a> MirBuilder<'a> { let trace_columns = &self.program.trace_columns; let boundary_constraints = &self.program.boundary_constraints; let integrity_constraints = &self.program.integrity_constraints; + let buses = &self.program.buses; self.mir.trace_columns.clone_from(trace_columns); self.mir.num_random_values = random_values.as_ref().map(|rv| rv.size as u16).unwrap_or(0); self.mir.periodic_columns = self.program.periodic_columns.clone(); self.mir.public_inputs = self.program.public_inputs.clone(); + for (qual_ident, ast_bus) in buses.iter() { + let bus = self.translate_bus_definition(ast_bus)?; + self.mir + .constraint_graph_mut() + .insert_bus(*qual_ident, bus)?; + } for (ident, function) in &self.program.functions { self.translate_function_signature(ident, function)?; @@ -114,6 +123,10 @@ impl<'a> MirBuilder<'a> { Ok(()) } + fn translate_bus_definition(&mut self, bus: &'a ast::Bus) -> Result, CompileError> { + Ok(Bus::create(bus.name, bus.bus_type.clone(), bus.span())) + } + fn translate_evaluator_signature( &mut self, ident: &'a ast::QualifiedIdentifier, @@ -128,9 +141,7 @@ impl<'a> MirBuilder<'a> { for trace_segment in &ast_eval.params { let mut all_params_flatten_for_trace_segment = Vec::new(); - // println!("trace_segment: {:#?}", trace_segment); for binding in &trace_segment.bindings { - // println!("binding: {:#?}", binding); let span = binding.name.map_or(SourceSpan::UNKNOWN, |n| n.span()); let params = self.translate_params_ev(span, binding.name.as_ref(), &binding.ty, &mut i)?; @@ -342,9 +353,6 @@ impl<'a> MirBuilder<'a> { let func = func; for stmt in body { let op = self.translate_statement(stmt)?; - //println!("statement: {:#?}", stmt); - //println!("op: {:#?}", op); - //println!(); match func.clone().borrow().deref() { Root::Function(f) => f.body.borrow_mut().push(op.clone()), Root::Evaluator(e) => e.body.borrow_mut().push(op.clone()), @@ -373,6 +381,7 @@ impl<'a> MirBuilder<'a> { ast::Statement::Enforce(enf) => self.translate_enforce(enf), ast::Statement::EnforceIf(enf, cond) => self.translate_enforce_if(enf, cond), ast::Statement::EnforceAll(list_comp) => self.translate_enforce_all(list_comp), + ast::Statement::BusEnforce(list_comp) => self.translate_bus_enforce(list_comp), } } fn translate_let(&mut self, let_stmt: &'a ast::Let) -> Result, CompileError> { @@ -398,6 +407,7 @@ impl<'a> MirBuilder<'a> { ast::Expr::Call(c) => self.translate_call(c), ast::Expr::ListComprehension(lc) => self.translate_list_comprehension(lc), ast::Expr::Let(l) => self.translate_let(l), + ast::Expr::BusOperation(_) | ast::Expr::Null(_) => todo!(), } } @@ -466,6 +476,74 @@ impl<'a> MirBuilder<'a> { node } + fn translate_bus_enforce( + &mut self, + list_comp: &'a ast::ListComprehension, + ) -> Result, CompileError> { + let bus_op = self.translate_scalar_expr(&list_comp.body)?; + if list_comp.iterables.len() != 1 { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("expected a single iterable in bus enforce") + .with_primary_label( + list_comp.span(), + format!( + "expected a single iterable in bus enforce, got this instead: \n{:#?}", + list_comp.iterables + ), + ) + .emit(); + return Err(CompileError::Failed); + } + // Note: safe to unwrap because we checked the length above + let ast_iterables = list_comp.iterables.first().unwrap(); + // sanity check + match ast_iterables { + ast::Expr::Range(ast::RangeExpr { start, end, .. }) => { + let start = match start { + ast::RangeBound::Const(Span { item: val, .. }) => val, + _ => unimplemented!(), + }; + let end = match end { + ast::RangeBound::Const(Span { item: val, .. }) => val, + _ => unimplemented!(), + }; + if *start != 0 || *end != 1 { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("Bus comprehensions can only target a single latch") + .with_primary_label( + list_comp.span(), + format!( + "expected a range with a single value, got this instead: \n{:#?}", + list_comp.iterables + ), + ) + .emit(); + return Err(CompileError::Failed); + }; + } + _ => unimplemented!(), + }; + let sel = match list_comp.selector.as_ref() { + Some(selector) => self.translate_scalar_expr(selector)?, + None => todo!(), // Should we always have a selector? + }; + bus_op + .as_bus_op_mut() + .unwrap() + ._latch + .borrow_mut() + .clone_from(&sel.borrow()); + let bus_op_clone = bus_op.clone(); + let bus_op_ref = bus_op_clone.as_bus_op_mut().unwrap(); + let bus_link = bus_op_ref.bus.to_link().unwrap(); + let mut bus = bus_link.borrow_mut(); + bus.latches.push(sel.clone()); + bus.columns.push(bus_op.clone()); + Ok(bus_op) + } + fn insert_enforce(&mut self, node: Link) -> Result, CompileError> { let node_to_add = if let Op::Enf(_) = node.clone().borrow().deref() { node @@ -564,13 +642,35 @@ impl<'a> MirBuilder<'a> { }) .build(); Ok(node) + } else if let Some(bus) = self.mir.constraint_graph().get_bus_link(&qual_ident) { + let node = Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::BusAccess(BusAccess::new(bus.clone(), access.offset)), + }) + .build(); + Ok(node) } else { // This is a qualified reference that should have been eliminated // during inlining or constant propagation, but somehow slipped through. - unreachable!( - "expected reference to periodic column, got `{:?}` instead", - qual_ident - ); + self.diagnostics + .diagnostic(Severity::Error) + //", got `{:#?}` of {:#?} instead.", + .with_message("expected reference to periodic column") + .with_primary_label( + qual_ident.span(), + format!( + "expected reference to periodic column, got `{:#?}`", + qual_ident + ), + ) + .with_secondary_label( + access.span(), + format!("in this access expression `{:#?}`", access), + ) + .emit(); + //unreachable!("expected reference to periodic column in `{:#?}`", access); + Err(CompileError::Failed) } } // This must be one of public inputs, random values, or trace columns @@ -593,6 +693,49 @@ impl<'a> MirBuilder<'a> { ) -> Result, CompileError> { let lhs = self.translate_scalar_expr(&bin_op.lhs)?; let rhs = self.translate_scalar_expr(&bin_op.rhs)?; + + // Check if bin_op is a bus constraint, if so, add it to the Link + if let (Op::Boundary(lhs_boundary), true) = (lhs.borrow().deref(), self.in_boundary) { + let lhs_child = lhs_boundary.expr.clone(); + let kind = lhs_boundary.kind; + + let lhs_child_as_value_ref = lhs_child.as_value(); + if let Some(val_ref) = lhs_child_as_value_ref { + let SpannedMirValue { span: _span, value } = val_ref.value.clone(); + if let MirValue::BusAccess(bus_access) = value { + let bus = bus_access.bus; + match kind { + ast::Boundary::First => { + bus.borrow_mut().set_first(rhs.clone()).map_err(|_| { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("bus boundary constraint already set") + .with_primary_label( + bin_op.span(), + "bus boundary constraint already set", + ) + .emit(); + CompileError::Failed + })?; + } + ast::Boundary::Last => { + bus.borrow_mut().set_last(rhs.clone()).map_err(|_| { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("bus boundary constraint already set") + .with_primary_label( + bin_op.span(), + "bus boundary constraint already set", + ) + .emit(); + CompileError::Failed + })?; + } + } + } + } + } + match bin_op.op { ast::BinaryOp::Add => { let node = Add::builder().lhs(lhs).rhs(rhs).span(bin_op.span()).build(); @@ -618,8 +761,6 @@ impl<'a> MirBuilder<'a> { } fn translate_call(&mut self, call: &'a ast::Call) -> Result, CompileError> { - //println!("CALL ARGS: {:#?}", call); - // First, resolve the callee, panic if it's not resolved let resolved_callee = call.callee.resolved().unwrap(); @@ -807,6 +948,11 @@ impl<'a> MirBuilder<'a> { ast::ScalarExpr::Binary(b) => self.translate_binary_op(b), ast::ScalarExpr::Call(c) => self.translate_call(c), ast::ScalarExpr::Let(l) => self.translate_let(l), + ast::ScalarExpr::Null(_) => Ok(Value::create(SpannedMirValue { + span: scalar_expr.span(), + value: MirValue::Null, + })), + ast::ScalarExpr::BusOperation(bo) => self.translate_bus_operation(bo), } } @@ -836,6 +982,71 @@ impl<'a> MirBuilder<'a> { Ok(node) } + fn translate_bus_operation( + &mut self, + ast_bus_op: &'a ast::BusOperation, + ) -> Result, CompileError> { + let Some(bus_ident) = ast_bus_op.bus.resolved() else { + self.diagnostics + .diagnostic(Severity::Error) + .with_message(format!( + "expected a resolved bus identifier, got `{:#?}`", + ast_bus_op.bus + )) + .with_primary_label( + ast_bus_op.bus.span(), + "expected a resolved bus identifier here", + ) + .emit(); + return Err(CompileError::Failed); + }; + let Some(bus) = self.mir.constraint_graph().get_bus_link(&bus_ident) else { + self.diagnostics + .diagnostic(Severity::Error) + .with_message(format!( + "expected a known bus identifier here, got `{:#?}`", + ast_bus_op.bus + )) + .with_primary_label(ast_bus_op.bus.span(), "Unknown bus identifier") + .emit(); + return Err(CompileError::Failed); + }; + let bus_op_kind = match ast_bus_op.op { + ast::BusOperator::Add => BusOpKind::Add, + ast::BusOperator::Rem => BusOpKind::Rem, + }; + + let mut bus_op = BusOp::builder() + .span(ast_bus_op.span()) + .bus(bus) + .kind(bus_op_kind); + for arg in ast_bus_op.args.iter() { + let mut arg_node = self.translate_expr(arg)?; + let accessor_mut = arg_node.clone(); + if let Some(accessor) = accessor_mut.as_accessor_mut() { + match accessor.access_type { + AccessType::Default => { + arg_node = accessor.indexable.clone(); + } + _ => { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("expected default access type") + .with_primary_label( + arg.span(), + "expected default access type, got this instead", + ) + .emit(); + return Err(CompileError::Failed); + } + } + } + bus_op = bus_op.args(arg_node); + } + let bus_op = bus_op.build(); + Ok(bus_op) + } + fn translate_const( &mut self, c: &ast::ConstantExpr, @@ -968,23 +1179,68 @@ impl<'a> MirBuilder<'a> { .build()); } + // TODO: Will be used (instead of the if let above) when handling variable-length public inputs + /*match self.public_input_access(access) { + (Some(public_input), None) => { + return Ok(Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::PublicInput(public_input), + }) + .build()); + } + (None, Some(public_input_binding)) => { + return Ok(Value::builder() + .value(SpannedMirValue { + span: access.span(), + value: MirValue::PublicInputBinding(public_input_binding), + }) + .build()); + } + _ => {} + }*/ + panic!("undefined variable: {:?}", access); } // Check assumptions, probably this assumed that the inlining pass did some work fn public_input_access(&self, access: &ast::SymbolAccess) -> Option { let public_input = self.mir.public_inputs.get(access.name.as_ref())?; - if let AccessType::Index(index) = access.access_type { - Some(PublicInputAccess::new(public_input.name, index)) - } else { - // This should have been caught earlier during compilation - unreachable!( - "unexpected public input access type encountered during lowering: {:#?}", - access - ) + match access.access_type { + AccessType::Index(index) => Some(PublicInputAccess::new(public_input.name, index)), + _ => { + // This should have been caught earlier during compilation + unreachable!( + "unexpected public input access type encountered during lowering: {:#?}", + access + ) + } } } + // TODO: Will be used when handling variable-length public inputs + /*fn public_input_access( + &self, + access: &ast::SymbolAccess, + ) -> (Option, Option) { + let Some(public_input) = self.mir.public_inputs.get(access.name.as_ref()) else { + return (None, None); + }; + match access.access_type { + AccessType::Default => (None, Some(PublicInputBinding::new(public_input.name))), + AccessType::Index(index) => { + (Some(PublicInputAccess::new(public_input.name, index)), None) + } + _ => { + // This should have been caught earlier during compilation + unreachable!( + "unexpected public input access type encountered during lowering: {:#?}", + access + ) + } + } + }*/ + // Check assumptions, probably this assumed that the inlining pass did some work fn random_value_access(&self, access: &ast::SymbolAccess) -> Option { let rv = self.random_values.as_ref()?; diff --git a/mir/src/passes/unrolling.rs b/mir/src/passes/unrolling.rs index 50ccbe189..658f26435 100644 --- a/mir/src/passes/unrolling.rs +++ b/mir/src/passes/unrolling.rs @@ -214,6 +214,8 @@ impl UnrollingFirstPass<'_> { updated_value = Some(Vector::create(vec, value_ref.span())); } + MirValue::BusAccess(_) => {} + MirValue::Null => {} //MirValue::PublicInputBinding(_public_input_binding) => {} } } @@ -844,6 +846,11 @@ impl Visitor for UnrollingFirstPass<'_> { fn root_nodes_to_visit(&self, graph: &Graph) -> Vec> { let boundary_constraints_roots_ref = graph.boundary_constraints_roots.borrow(); let integrity_constraints_roots_ref = graph.integrity_constraints_roots.borrow(); + let bus_roots: Vec<_> = graph + .buses + .values() + .flat_map(|b| b.borrow().clone().columns.into_iter().collect::>()) + .collect(); let combined_roots = boundary_constraints_roots_ref .clone() .into_iter() @@ -853,7 +860,8 @@ impl Visitor for UnrollingFirstPass<'_> { .clone() .into_iter() .map(|ic| ic.as_node()), - ); + ) + .chain(bus_roots.into_iter().map(|b| b.as_node())); combined_roots.collect() } @@ -892,6 +900,7 @@ impl Visitor for UnrollingFirstPass<'_> { Node::Accessor(a) => { to_link_and(a.clone(), graph, |g, el| self.visit_accessor_bis(g, el)) } + Node::BusOp(_b) => Ok(None), Node::Parameter(p) => { to_link_and(p.clone(), graph, |g, el| self.visit_parameter_bis(g, el)) } diff --git a/mir/src/passes/visitor.rs b/mir/src/passes/visitor.rs index 9df811f82..89d3d46e1 100644 --- a/mir/src/passes/visitor.rs +++ b/mir/src/passes/visitor.rs @@ -59,6 +59,7 @@ pub trait Visitor { Node::Vector(v) => self.visit_vector(graph, v.clone().into()), Node::Matrix(m) => self.visit_matrix(graph, m.clone().into()), Node::Accessor(a) => self.visit_accessor(graph, a.clone().into()), + Node::BusOp(b) => self.visit_bus_op(graph, b.clone().into()), Node::Parameter(p) => self.visit_parameter(graph, p.clone().into()), Node::Value(v) => self.visit_value(graph, v.clone().into()), Node::None(_) => Ok(()), @@ -140,6 +141,10 @@ pub trait Visitor { ) -> Result<(), CompileError> { Ok(()) } + /// Visit a BusOp node + fn visit_bus_op(&mut self, _graph: &mut Graph, _bus_op: Link) -> Result<(), CompileError> { + Ok(()) + } /// Visit a Parameter node fn visit_parameter( &mut self, diff --git a/mir/src/tests/buses.rs b/mir/src/tests/buses.rs new file mode 100644 index 000000000..ef666bc00 --- /dev/null +++ b/mir/src/tests/buses.rs @@ -0,0 +1,207 @@ +use crate::{ + ir::{assert_bus_eq, Add, Builder, Bus, Fold, FoldOperator, Link, Mir, Op, Vector}, + tests::translate, +}; +use air_parser::{ast, Symbol}; +use miden_diagnostics::SourceSpan; + +use super::{compile, expect_diagnostic}; + +#[test] +fn buses_in_boundary_constraints() { + let source = " + def test + + trace_columns { + main: [a], + } + + buses { + unit p, + mult q, + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf p.first = null; + enf q.first = null; + enf p.last = null; + enf q.last = null; + # TODO: to be used when we have support for variable-length public inputs + #enf p.last = inputs; + #enf q.last = inputs; + } + + integrity_constraints { + enf a = 0; + }"; + assert!(compile(source).is_ok()); +} + +#[test] +fn buses_in_integrity_constraints() { + let source = " + def test + + trace_columns { + main: [a], + } + + buses { + unit p, + mult q, + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf p.first = null; + enf q.first = null; + enf p.last = null; + enf q.last = null; + # TODO: to be used when we have support for variable-length public inputs + #enf p.last = inputs; + #enf q.last = inputs; + } + + integrity_constraints { + p.add(1) when 1; + p.rem(1) when 1; + q.add(1, 2) when 1; + q.add(1, 2) when 1; + q.rem(1, 2) for 2; + }"; + + assert!(compile(source).is_ok()); +} + +#[test] +fn buses_args_expr_in_integrity_expr() { + let source = " + def test + + trace_columns { + main: [a], + } + + public_inputs { + inputs: [2], + } + + buses { + unit p, + } + + boundary_constraints { + enf p.first = null; + } + + integrity_constraints { + let vec = [x for x in 0..3]; + let b = 41; + let x = sum(vec) + b; + p.add(x) when 1; + p.rem(x) when 0; + }"; + assert!(compile(source).is_ok()); + let mut result_mir = translate(source).unwrap(); + let bus = Bus::create( + ast::Identifier::new(SourceSpan::default(), Symbol::new(0)), + ast::BusType::Unit, + SourceSpan::default(), + ); + let vec_op = Vector::builder() + .size(3) + .elements(From::from(0)) + .elements(From::from(1)) + .elements(From::from(2)) + .span(SourceSpan::default()) + .build(); + let b: Link = From::from(41); + let vec_sum = Fold::builder() + .iterator(vec_op) + .operator(FoldOperator::Add) + .initial_value(From::from(0)) + .span(SourceSpan::default()) + .build(); + let x: Link = Add::builder() + .lhs(vec_sum) + .rhs(b.clone()) + .span(SourceSpan::default()) + .build(); + let sel: Link = From::from(1); + let _p_add = bus.add(&[x.clone()], sel.clone(), SourceSpan::default()); + let not_sel: Link = From::from(0); + let _p_rem = bus.rem(&[x.clone()], not_sel.clone(), SourceSpan::default()); + let bus_ident = result_mir.constraint_graph().buses.keys().next().unwrap(); + let mut expected_mir = Mir::new(result_mir.name); + let _ = expected_mir + .constraint_graph_mut() + .insert_bus(*bus_ident, bus.clone().clone()); + assert_bus_eq(&mut expected_mir, &mut result_mir); +} + +// Tests that should return errors +#[test] +fn err_buses_boundaries_to_const() { + let source = " + def test + + trace_columns { + main: [a], + } + + buses { + unit p, + mult q, + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf p.first = 0; + enf q.last = null; + } + + integrity_constraints { + enf a = 0; + }"; + + expect_diagnostic(source, "error: invalid constraint"); +} + +#[test] +fn err_trace_columns_constrained_with_null() { + let source = " + def test + + trace_columns { + main: [a], + } + + buses { + unit p, + mult q, + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf a.last = null; + } + + integrity_constraints { + enf a = 0; + }"; + + expect_diagnostic(source, "error: invalid constraint"); +} diff --git a/mir/src/tests/mod.rs b/mir/src/tests/mod.rs index b3053ba61..c5900e571 100644 --- a/mir/src/tests/mod.rs +++ b/mir/src/tests/mod.rs @@ -1,5 +1,6 @@ mod access; mod boundary_constraints; +mod buses; mod constant; mod evaluators; mod functions; @@ -51,6 +52,19 @@ pub fn translate(source: &str) -> Result { } } +#[allow(dead_code)] +pub fn parse(source: &str) -> Result { + let compiler = Compiler::default(); + match compiler.parse(source) { + Ok(ast) => Ok(ast), + Err(err) => { + compiler.diagnostics.emit(err); + compiler.emitter.print_captured_to_stderr(); + Err(()) + } + } +} + #[track_caller] pub fn expect_diagnostic(source: &str, expected: &str) { let compiler = Compiler::default(); @@ -112,7 +126,8 @@ impl Compiler { air_parser::transforms::ConstantPropagation::new(&self.diagnostics) .chain(crate::passes::AstToMir::new(&self.diagnostics)) .chain(crate::passes::Inlining::new(&self.diagnostics)) - .chain(crate::passes::Unrolling::new(&self.diagnostics)); + .chain(crate::passes::Unrolling::new(&self.diagnostics)) + .chain(crate::passes::BusOpExpand::new(&self.diagnostics)); pipeline.run(ast) }) } @@ -126,6 +141,11 @@ impl Compiler { pipeline.run(ast) }) } + #[allow(dead_code)] + pub fn parse(&self, source: &str) -> Result { + air_parser::parse(&self.diagnostics, self.codemap.clone(), source) + .map_err(CompileError::Parse) + } } struct SplitEmitter { diff --git a/mir/src/tests/trace.rs b/mir/src/tests/trace.rs index a8beddd46..c31d9cb5e 100644 --- a/mir/src/tests/trace.rs +++ b/mir/src/tests/trace.rs @@ -64,7 +64,7 @@ fn err_bc_column_undeclared() { enf clk' = clk + 1; }"; - expect_diagnostic(source, "this variable is not defined"); + expect_diagnostic(source, "this variable / bus is not defined"); } #[test] @@ -84,7 +84,7 @@ fn err_ic_column_undeclared() { enf clk' = clk + 1; }"; - expect_diagnostic(source, "this variable is not defined"); + expect_diagnostic(source, "this variable / bus is not defined"); } #[test] diff --git a/mir/src/tests/variables.rs b/mir/src/tests/variables.rs index e9eaa3cd3..421cd1bde 100644 --- a/mir/src/tests/variables.rs +++ b/mir/src/tests/variables.rs @@ -183,7 +183,7 @@ fn invalid_matrix_literal_with_leading_vector_binding() { enf clk' = d[0][0]; }"; - expect_diagnostic(source, "expected one of: '\"!\"', '\"(\"', 'decl_ident_ref', 'function_identifier', 'identifier', 'int'"); + expect_diagnostic(source, "expected one of: '\"!\"', '\"(\"', '\"null\"', 'decl_ident_ref', 'function_identifier', 'identifier', 'int'"); } #[test] @@ -231,7 +231,7 @@ fn invalid_variable_access_before_declaration() { enf clk' = clk + 1; }"; - expect_diagnostic(source, "this variable is not defined"); + expect_diagnostic(source, "this variable / bus is not defined"); } #[test] @@ -277,7 +277,7 @@ fn invalid_reference_to_variable_defined_in_other_section() { enf clk' = clk + a; }"; - expect_diagnostic(source, "this variable is not defined"); + expect_diagnostic(source, "this variable / bus is not defined"); } #[test] diff --git a/parser/src/ast/declarations.rs b/parser/src/ast/declarations.rs index 3fffe9115..134fc2588 100644 --- a/parser/src/ast/declarations.rs +++ b/parser/src/ast/declarations.rs @@ -35,6 +35,8 @@ use super::*; pub enum Declaration { /// Import one or more items from the specified AirScript module to the current module Import(Span), + /// A Bus section declaration + Buses(Span>), /// A constant value declaration Constant(Constant), /// An evaluator function definition @@ -76,6 +78,56 @@ pub enum Declaration { IntegrityConstraints(Span>), } +/// Represents a bus declaration in an AirScript module. +#[derive(Debug, Clone, Spanned)] +pub struct Bus { + #[span] + pub span: SourceSpan, + pub name: Identifier, + pub bus_type: BusType, +} +impl Bus { + /// Creates a new bus declaration + pub const fn new(span: SourceSpan, name: Identifier, bus_type: BusType) -> Self { + Self { + span, + name, + bus_type, + } + } +} +#[derive(Default, Hash, Debug, Clone, PartialEq, Eq)] +pub enum BusType { + /// A multiset bus + #[default] + Unit, + /// A logup bus + Mult, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BusOperator { + /// Add a tuple to the bus + Add, + /// Remove a tuple from the bus + Rem, +} +impl std::fmt::Display for BusOperator { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + Self::Add => write!(f, "add"), + Self::Rem => write!(f, "rem"), + } + } +} + +impl Eq for Bus {} +impl PartialEq for Bus { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.bus_type == other.bus_type + } +} + /// Stores a constant's name and value. There are three types of constants: /// /// * Scalar: 123 diff --git a/parser/src/ast/display.rs b/parser/src/ast/display.rs index 5cc5a7823..7d5f95d62 100644 --- a/parser/src/ast/display.rs +++ b/parser/src/ast/display.rs @@ -111,6 +111,7 @@ impl fmt::Display for DisplayStatement<'_> { write!(f, "enf {}", expr) } Statement::Expr(ref expr) => write!(f, "return {}", expr), + Statement::BusEnforce(ref expr) => write!(f, "bus_enf {}", expr), } } } diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index 7aa7a5aaf..4569d011f 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -22,7 +22,7 @@ pub type Range = std::ops::Range; /// Represents any type of identifier in AirScript #[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Spanned)] -pub struct Identifier(Span); +pub struct Identifier(pub Span); impl Identifier { pub fn new(span: SourceSpan, name: Symbol) -> Self { Self(Span::new(span, name)) @@ -91,11 +91,12 @@ impl From for Identifier { /// Represents an identifier qualified with its namespace. /// /// Identifiers in AirScript are separated into two namespaces: one for functions, -/// and one for bindings. This is because functions cannot be bound, and bindings -/// cannot be called, so we can always disambiguate identifiers based on its usage. +/// and one for buses and bindings. This is because functions cannot be bound, added to or remove from, +/// buses cannot be called, and bindings cannot be called either +/// So we can always disambiguate identifiers based on its usage. /// -/// It is still probably best practice to avoid having name conflicts between functions -/// and bindings, but that is a matter of style rather than one of necessity. +/// It is still probably best practice to avoid having name conflicts between functions, +/// buses and bindings, but that is a matter of style rather than one of necessity. #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Spanned)] pub enum NamespacedIdentifier { Function(#[span] Identifier), @@ -301,6 +302,10 @@ pub enum Expr { /// NOTE: The AirScript syntax only permits `let` in statement position, so this variant /// is only present in the AST as the result of an explicit transformation. Let(Box), + /// A bus operation (`p.add(...)` or `p.rem(...)`) + BusOperation(BusOperation), + /// An empty bus + Null(Span<()>), } impl Expr { /// Returns true if this expression is constant @@ -335,6 +340,7 @@ impl Expr { Self::Call(ref call) => call.ty, Self::ListComprehension(ref lc) => lc.ty, Self::Let(ref let_expr) => let_expr.ty(), + Self::BusOperation(_) | Self::Null(_) => todo!(), } } } @@ -352,6 +358,8 @@ impl fmt::Debug for Expr { f.debug_tuple("ListComprehension").field(expr).finish() } Self::Let(ref let_expr) => write!(f, "{let_expr:#?}"), + Self::BusOperation(ref expr) => f.debug_tuple("BusOp").field(expr).finish(), + Self::Null(ref expr) => f.debug_tuple("Null").field(expr).finish(), } } } @@ -383,6 +391,8 @@ impl fmt::Display for Expr { }; write!(f, "{display}") } + Self::BusOperation(ref expr) => write!(f, "{}", expr), + Self::Null(ref _expr) => write!(f, "null"), } } } @@ -404,6 +414,12 @@ impl From for Expr { Self::Call(expr) } } +impl From for Expr { + #[inline] + fn from(expr: BusOperation) -> Self { + Self::BusOperation(expr) + } +} impl From for Expr { #[inline] fn from(expr: ListComprehension) -> Self { @@ -438,6 +454,8 @@ impl TryFrom for Expr { Err(InvalidExprError::BoundedSymbolAccess(expr.span())) } ScalarExpr::Let(expr) => Ok(Self::Let(expr)), + ScalarExpr::BusOperation(expr) => Ok(Self::BusOperation(expr)), + ScalarExpr::Null(spanned) => Ok(Self::Null(spanned)), } } } @@ -486,6 +504,10 @@ pub enum ScalarExpr { /// binary expressions or function calls to a block of statements, and only when the result /// of evaluating the `let` produces a valid scalar expression. Let(Box), + /// A bus operation + BusOperation(BusOperation), + /// An empty bus + Null(Span<()>), } impl ScalarExpr { /// Returns true if this is a constant value @@ -520,6 +542,7 @@ impl ScalarExpr { }, Self::Call(ref expr) => Ok(expr.ty), Self::Let(ref expr) => Ok(expr.ty()), + Self::BusOperation(_) | ScalarExpr::Null(_) => todo!(), } } } @@ -576,6 +599,8 @@ impl fmt::Debug for ScalarExpr { Self::Binary(ref expr) => f.debug_tuple("Binary").field(expr).finish(), Self::Call(ref expr) => f.debug_tuple("Call").field(expr).finish(), Self::Let(ref expr) => write!(f, "{:#?}", expr), + Self::BusOperation(ref expr) => f.debug_tuple("BusOp").field(expr).finish(), + Self::Null(ref expr) => f.debug_tuple("Null").field(expr).finish(), } } } @@ -595,6 +620,8 @@ impl fmt::Display for ScalarExpr { }; write!(f, "{display}") } + Self::BusOperation(ref expr) => write!(f, "{}", expr), + Self::Null(ref _value) => write!(f, "null"), } } } @@ -821,9 +848,10 @@ impl fmt::Display for Boundary { } /// Represents the way an identifier is accessed/referenced in the source. -#[derive(Hash, Debug, Clone, Eq, PartialEq)] +#[derive(Hash, Debug, Clone, Eq, PartialEq, Default)] pub enum AccessType { /// Access refers to the entire bound value + #[default] Default, /// Access binds a sub-slice of a vector Slice(RangeExpr), @@ -1253,6 +1281,53 @@ impl fmt::Display for ListComprehension { } } +#[derive(Clone, Spanned)] +pub struct BusOperation { + #[span] + pub span: SourceSpan, + pub bus: ResolvableIdentifier, + pub op: BusOperator, + pub args: Vec, +} + +impl BusOperation { + pub fn new(span: SourceSpan, bus: Identifier, op: BusOperator, args: Vec) -> Self { + Self { + span, + bus: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Binding(bus)), + op, + args, + } + } +} + +impl Eq for BusOperation {} +impl PartialEq for BusOperation { + fn eq(&self, other: &Self) -> bool { + self.bus == other.bus && self.args == other.args && self.op == other.op + } +} +impl fmt::Debug for BusOperation { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("BusOperation") + .field("bus", &self.bus) + .field("op", &self.op) + .field("args", &self.args) + .finish() + } +} +impl fmt::Display for BusOperation { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!( + f, + "{}{}{}", + self.bus, + self.op, + DisplayTuple(self.args.as_slice()) + ) + } +} + /// Represents a function call (either a pure function or an evaluator). /// /// Calls are permitted in a scalar expression context, but arguments to the diff --git a/parser/src/ast/mod.rs b/parser/src/ast/mod.rs index b02b952c4..c61145a24 100644 --- a/parser/src/ast/mod.rs +++ b/parser/src/ast/mod.rs @@ -73,6 +73,8 @@ pub struct Program { pub evaluators: BTreeMap, /// The set of used pure functions referenced in this program. pub functions: BTreeMap, + /// The set of used buses referenced in this program. + pub buses: BTreeMap, /// The set of used periodic columns referenced in this program. pub periodic_columns: BTreeMap, /// The set of public inputs defined in the root module @@ -118,6 +120,7 @@ impl Program { constants: Default::default(), evaluators: Default::default(), functions: Default::default(), + buses: Default::default(), periodic_columns: Default::default(), public_inputs: Default::default(), random_values: None, @@ -234,6 +237,15 @@ impl Program { if let Some(ic) = root_module.integrity_constraints.as_ref() { program.integrity_constraints = ic.to_vec(); } + // Make sure we move the buses into the program + if !root_module.buses.is_empty() { + program.buses = BTreeMap::from_iter(root_module.buses.iter().map(|(k, v)| { + ( + QualifiedIdentifier::new(root, NamespacedIdentifier::Binding(*k)), + v.clone(), + ) + })); + } for evaluator in root_module.evaluators.values() { root_nodes.push_back(QualifiedIdentifier::new( root, diff --git a/parser/src/ast/module.rs b/parser/src/ast/module.rs index 2d306b360..d7e5fd5bd 100644 --- a/parser/src/ast/module.rs +++ b/parser/src/ast/module.rs @@ -59,6 +59,7 @@ pub struct Module { pub public_inputs: BTreeMap, pub random_values: Option, pub trace_columns: Vec, + pub buses: BTreeMap, pub boundary_constraints: Option>>, pub integrity_constraints: Option>>, } @@ -81,6 +82,7 @@ impl Module { constants: Default::default(), evaluators: Default::default(), functions: Default::default(), + buses: Default::default(), periodic_columns: Default::default(), public_inputs: Default::default(), random_values: None, @@ -152,6 +154,11 @@ impl Module { Declaration::IntegrityConstraints(statements) => { module.declare_integrity_constraints(diagnostics, statements)?; } + Declaration::Buses(mut buses) => { + for bus in buses.drain(..) { + module.declare_bus(diagnostics, &mut names, bus)?; + } + } } } @@ -416,6 +423,22 @@ impl Module { Ok(()) } + fn declare_bus( + &mut self, + diagnostics: &DiagnosticsHandler, + names: &mut HashSet, + bus: Bus, + ) -> Result<(), SemanticAnalysisError> { + if let Some(prev) = names.replace(NamespacedIdentifier::Binding(bus.name)) { + conflicting_declaration(diagnostics, "bus", prev.span(), bus.name.span()); + return Err(SemanticAnalysisError::NameConflict(bus.name.span())); + } + + self.buses.insert(bus.name, bus); + + Ok(()) + } + fn declare_periodic_column( &mut self, diagnostics: &DiagnosticsHandler, diff --git a/parser/src/ast/statement.rs b/parser/src/ast/statement.rs index 3e13723c9..ceb94d08f 100644 --- a/parser/src/ast/statement.rs +++ b/parser/src/ast/statement.rs @@ -60,6 +60,8 @@ pub enum Statement { /// Just like `Enforce`, except the constraint is contained in the body of a list comprehension, /// and must be enforced on every value produced by that comprehension. EnforceAll(ListComprehension), + /// Declares a bus related constraint + BusEnforce(ListComprehension), } impl Statement { /// Checks this statement to see if it contains any constraints @@ -69,7 +71,10 @@ impl Statement { /// one or more constraints in its body. pub fn has_constraints(&self) -> bool { match self { - Self::Enforce(_) | Self::EnforceIf(_, _) | Self::EnforceAll(_) => true, + Self::Enforce(_) + | Self::EnforceIf(_, _) + | Self::EnforceAll(_) + | Self::BusEnforce(_) => true, Self::Let(Let { body, .. }) => body.iter().any(|s| s.has_constraints()), Self::Expr(_) => false, } @@ -161,9 +166,10 @@ impl Let { last = let_expr.body.last(); } Statement::Expr(ref expr) => return expr.ty(), - Statement::Enforce(_) | Statement::EnforceIf(_, _) | Statement::EnforceAll(_) => { - break - } + Statement::Enforce(_) + | Statement::EnforceIf(_, _) + | Statement::EnforceAll(_) + | Statement::BusEnforce(_) => break, } } diff --git a/parser/src/ast/visit.rs b/parser/src/ast/visit.rs index 9876f3296..f0dcd2e3a 100644 --- a/parser/src/ast/visit.rs +++ b/parser/src/ast/visit.rs @@ -125,6 +125,9 @@ pub trait VisitMut { fn visit_mut_function(&mut self, expr: &mut ast::Function) -> ControlFlow { visit_mut_function(self, expr) } + fn visit_mut_bus(&mut self, expr: &mut ast::Bus) -> ControlFlow { + visit_mut_bus(self, expr) + } fn visit_mut_periodic_column(&mut self, expr: &mut ast::PeriodicColumn) -> ControlFlow { visit_mut_periodic_column(self, expr) } @@ -184,6 +187,9 @@ pub trait VisitMut { fn visit_mut_enforce_all(&mut self, expr: &mut ast::ListComprehension) -> ControlFlow { self.visit_mut_list_comprehension(expr) } + fn visit_mut_bus_enforce(&mut self, expr: &mut ast::ListComprehension) -> ControlFlow { + self.visit_mut_list_comprehension(expr) + } fn visit_mut_integrity_constraints( &mut self, exprs: &mut Vec, @@ -208,6 +214,9 @@ pub trait VisitMut { fn visit_mut_call(&mut self, expr: &mut ast::Call) -> ControlFlow { visit_mut_call(self, expr) } + fn visit_mut_bus_operation(&mut self, expr: &mut ast::BusOperation) -> ControlFlow { + visit_mut_bus_operation(self, expr) + } fn visit_mut_range_bound(&mut self, expr: &mut ast::RangeBound) -> ControlFlow { visit_mut_range_bound(self, expr) } @@ -268,6 +277,9 @@ where fn visit_mut_function(&mut self, expr: &mut ast::Function) -> ControlFlow { (**self).visit_mut_function(expr) } + fn visit_mut_bus(&mut self, expr: &mut ast::Bus) -> ControlFlow { + (**self).visit_mut_bus(expr) + } fn visit_mut_periodic_column(&mut self, expr: &mut ast::PeriodicColumn) -> ControlFlow { (**self).visit_mut_periodic_column(expr) } @@ -332,6 +344,9 @@ where fn visit_mut_enforce_all(&mut self, expr: &mut ast::ListComprehension) -> ControlFlow { (**self).visit_mut_enforce_all(expr) } + fn visit_mut_bus_enforce(&mut self, expr: &mut ast::ListComprehension) -> ControlFlow { + (**self).visit_mut_bus_enforce(expr) + } fn visit_mut_expr(&mut self, expr: &mut ast::Expr) -> ControlFlow { (**self).visit_mut_expr(expr) } @@ -350,6 +365,9 @@ where fn visit_mut_call(&mut self, expr: &mut ast::Call) -> ControlFlow { (**self).visit_mut_call(expr) } + fn visit_mut_bus_operation(&mut self, expr: &mut ast::BusOperation) -> ControlFlow { + (**self).visit_mut_bus_operation(expr) + } fn visit_mut_range_bound(&mut self, expr: &mut ast::RangeBound) -> ControlFlow { (**self).visit_mut_range_bound(expr) } @@ -404,6 +422,9 @@ where for function in module.functions.values_mut() { visitor.visit_mut_function(function)?; } + for bus in module.buses.values_mut() { + visitor.visit_mut_bus(bus)?; + } for column in module.periodic_columns.values_mut() { visitor.visit_mut_periodic_column(column)?; } @@ -495,6 +516,13 @@ where visitor.visit_mut_statement_block(&mut expr.body) } +pub fn visit_mut_bus(visitor: &mut V, expr: &mut ast::Bus) -> ControlFlow +where + V: ?Sized + VisitMut, +{ + visitor.visit_mut_identifier(&mut expr.name) +} + pub fn visit_mut_evaluator_trace_segment( visitor: &mut V, expr: &mut ast::TraceSegment, @@ -586,6 +614,7 @@ where } ast::Statement::EnforceAll(ref mut expr) => visitor.visit_mut_enforce_all(expr), ast::Statement::Expr(ref mut expr) => visitor.visit_mut_expr(expr), + ast::Statement::BusEnforce(ref mut expr) => visitor.visit_mut_bus_enforce(expr), } } @@ -631,6 +660,8 @@ where ast::Expr::Call(ref mut expr) => visitor.visit_mut_call(expr), ast::Expr::ListComprehension(ref mut expr) => visitor.visit_mut_list_comprehension(expr), ast::Expr::Let(ref mut expr) => visitor.visit_mut_let(expr), + ast::Expr::BusOperation(ref mut expr) => visitor.visit_mut_bus_operation(expr), + ast::Expr::Null(_) => ControlFlow::Continue(()), } } @@ -639,7 +670,7 @@ where V: ?Sized + VisitMut, { match expr { - ast::ScalarExpr::Const(_) => ControlFlow::Continue(()), + ast::ScalarExpr::Const(_) | ast::ScalarExpr::Null(_) => ControlFlow::Continue(()), ast::ScalarExpr::SymbolAccess(ref mut expr) => visitor.visit_mut_symbol_access(expr), ast::ScalarExpr::BoundedSymbolAccess(ref mut expr) => { visitor.visit_mut_bounded_symbol_access(expr) @@ -647,6 +678,7 @@ where ast::ScalarExpr::Binary(ref mut expr) => visitor.visit_mut_binary_expr(expr), ast::ScalarExpr::Call(ref mut expr) => visitor.visit_mut_call(expr), ast::ScalarExpr::Let(ref mut expr) => visitor.visit_mut_let(expr), + ast::ScalarExpr::BusOperation(ref mut expr) => visitor.visit_mut_bus_operation(expr), } } @@ -688,6 +720,20 @@ where ControlFlow::Continue(()) } +pub fn visit_mut_bus_operation( + visitor: &mut V, + expr: &mut ast::BusOperation, +) -> ControlFlow +where + V: ?Sized + VisitMut, +{ + visitor.visit_mut_resolvable_identifier(&mut expr.bus)?; + for arg in expr.args.iter_mut() { + visitor.visit_mut_expr(arg)?; + } + ControlFlow::Continue(()) +} + pub fn visit_mut_range_bound(visitor: &mut V, expr: &mut ast::RangeBound) -> ControlFlow where V: ?Sized + VisitMut, diff --git a/parser/src/lexer/mod.rs b/parser/src/lexer/mod.rs index ee94c9f28..4e7b79404 100644 --- a/parser/src/lexer/mod.rs +++ b/parser/src/lexer/mod.rs @@ -116,6 +116,21 @@ pub enum Token { /// Keyword to declare the function section in the AIR constraints module. Fn, + // BUSES KEYWORDS + // -------------------------------------------------------------------------------------------- + /// Marks the beginning of buses section in the constraints file. + Buses, + /// Used to represent a multiset bus declaration. + Unit, + /// Used to represent a logup bus declaration. + Mult, + /// Used to represent an empty bus + Null, + /// Used to represent the addition of a given tuple to a bus + Add, + /// Used to represent the removal of a given tuple to a bus + Rem, + // BOUNDARY CONSTRAINT KEYWORDS // -------------------------------------------------------------------------------------------- /// Marks the beginning of boundary constraints section in the constraints file. @@ -187,6 +202,12 @@ impl Token { "ev" => Self::Ev, "fn" => Self::Fn, "felt" => Self::Felt, + "buses" => Self::Buses, + "unit" => Self::Unit, + "mult" => Self::Mult, + "null" => Self::Null, + "add" => Self::Add, + "rem" => Self::Rem, "boundary_constraints" => Self::BoundaryConstraints, "integrity_constraints" => Self::IntegrityConstraints, "first" => Self::First, @@ -260,6 +281,12 @@ impl fmt::Display for Token { Self::Ev => write!(f, "ev"), Self::Fn => write!(f, "fn"), Self::Felt => write!(f, "felt"), + Self::Buses => write!(f, "buses"), + Self::Unit => write!(f, "unit"), + Self::Mult => write!(f, "mult"), + Self::Null => write!(f, "null"), + Self::Add => write!(f, "add"), + Self::Rem => write!(f, "rem"), Self::BoundaryConstraints => write!(f, "boundary_constraints"), Self::First => write!(f, "first"), Self::Last => write!(f, "last"), @@ -578,6 +605,7 @@ where let next = self.read(); match Token::from_keyword_or_ident(self.slice()) { Token::Ident(id) if next == '(' => Token::FunctionIdent(id), + //Token::Ident(id) if next == '.' => Token::BusIdent(id), token => token, } } diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index c9ac9ad27..245660384 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -76,6 +76,7 @@ Declaration: Declaration = { RandomValues => Declaration::RandomValues(<>), EvaluatorFunction => Declaration::EvaluatorFunction(<>), Function => Declaration::Function(<>), + Buses => Declaration::Buses(<>), => Declaration::Trace(Span::new(span!(l, r), trace)), => Declaration::PublicInputs(<>), => Declaration::BoundaryConstraints(<>), @@ -180,6 +181,26 @@ PeriodicColumn: PeriodicColumn = { => PeriodicColumn::new(span!(l, r), name, values), } + +// BUSES +// ================================================================================================ + +// Buses are not required, and there is no limit to the number that can be provided. +Buses: Span> = { + "buses" "{" "}" + => Span::new(span!(l, r), bus) +} + +Bus: Bus = { + "," + => Bus::new(span!(l, r), name, bus_type), +} + +BusType: BusType = { + "unit" => BusType::Unit, + "mult" => BusType::Mult, +} + // RANDOM VALUES // ================================================================================================ @@ -340,6 +361,7 @@ ConstraintStatements: Vec = { ConstraintStatement: Vec = { "enf" "match" "{" "}" ";" => <>, "enf" ";" => vec![<>], + ";" => vec![<>], } ReturnStatement: Expr = { @@ -405,15 +427,60 @@ ConstraintExpr: Statement = { } } +// 1. `p.add(a, b) when s` +// 2. `q.add(a, b) for m` +BusConstraintExpr: Statement = { + => { + let generated_name = format!("%{}", *next_var); + *next_var += 1; + let generated_binding = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern(generated_name)); + let context = vec![(generated_binding, Expr::Range(RangeExpr::from(0..1)))]; + Statement::BusEnforce(ListComprehension::new(span!(l, r), expr, context, Some(selector))) + }, + => { + let generated_name = format!("%{}", *next_var); + *next_var += 1; + let generated_binding = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern(generated_name)); + let context = vec![(generated_binding, Expr::Range(RangeExpr::from(0..1)))]; + Statement::BusEnforce(ListComprehension::new(span!(l, r), expr, context, Some(multiplicity))) + } +} + +// 1. `p.first = null` +// 2. `q.first = ???` To define +BoundaryBusConstraintExpr: Statement = { + => { + let generated_name = format!("%{}", *next_var); + *next_var += 1; + let generated_binding = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern(generated_name)); + let context = vec![(generated_binding, Expr::Range(RangeExpr::from(0..1)))]; + Statement::BusEnforce(ListComprehension::new(span!(l, r), expr, context, None)) + } +} + ScalarConstraintExpr: ScalarExpr = { FunctionCall, "=" => ScalarExpr::Binary(BinaryExpr::new(span!(l, r), BinaryOp::Eq, lhs, rhs)), } +ScalarBusConstraintExpr: ScalarExpr = { + "." "(" > ")" => { + ScalarExpr::BusOperation(BusOperation::new(span!(l, r), bus, bus_operator, args)) + } +} + +BusOperator: BusOperator = { + "add" => BusOperator::Add, + "rem" => BusOperator::Rem +} + WithSelector: ScalarExpr = { "when" , } +WithMultiplicity: ScalarExpr = { + "for" , +} Expr: Expr = { =>? { @@ -450,6 +517,7 @@ ScalarExprBase: ScalarExpr = { #[precedence(level="0")] SymbolAccess, => ScalarExpr::Const(<>), + "null" => ScalarExpr::Null(Span::new(span!(l, r), ())), "(" ")", #[precedence(level="1")] @@ -646,6 +714,12 @@ extern { "public_inputs" => Token::PublicInputs, "periodic_columns" => Token::PeriodicColumns, "random_values" => Token::RandomValues, + "buses" => Token::Buses, + "unit" => Token::Unit, + "mult" => Token::Mult, + "null" => Token::Null, + "add" => Token::Add, + "rem" => Token::Rem, "boundary_constraints" => Token::BoundaryConstraints, "first" => Token::First, "last" => Token::Last, diff --git a/parser/src/parser/tests/boundary_constraints.rs b/parser/src/parser/tests/boundary_constraints.rs index 08ea2bb39..ecadd7b00 100644 --- a/parser/src/parser/tests/boundary_constraints.rs +++ b/parser/src/parser/tests/boundary_constraints.rs @@ -14,6 +14,11 @@ trace_columns { main: [clk], } +buses { + unit p, + mult q, +} + public_inputs { inputs: [2], } @@ -32,6 +37,11 @@ integrity_constraints { /// main: [clk] /// } /// +/// buses { +/// unit p, +/// mult q, +/// } +/// /// public_inputs { /// inputs: [2] /// } @@ -51,6 +61,14 @@ fn test_module() -> Module { ident!(inputs), PublicInput::new(SourceSpan::UNKNOWN, ident!(inputs), 2), ); + expected.buses.insert( + ident!(p), + Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Unit), + ); + expected.buses.insert( + ident!(q), + Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Mult), + ); expected.integrity_constraints = Some(Span::new( SourceSpan::UNKNOWN, vec![enforce!(eq!(access!(clk), int!(0)))], @@ -102,6 +120,29 @@ fn boundary_constraint_at_last() { ParseTest::new().expect_module_ast(&source, expected); } +#[test] +fn boundary_constraint_with_buses() { + let source = format!( + " + {BASE_MODULE} + + boundary_constraints {{ + enf p.first = null; + enf q.last = null; + }}" + ); + + let mut expected = test_module(); + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![ + enforce!(eq!(bounded_access!(p, Boundary::First), null!())), + enforce!(eq!(bounded_access!(q, Boundary::Last), null!())), + ], + )); + ParseTest::new().expect_module_ast(&source, expected); +} + #[test] fn error_invalid_boundary() { let source = format!( diff --git a/parser/src/parser/tests/buses.rs b/parser/src/parser/tests/buses.rs new file mode 100644 index 000000000..dd24b8ea7 --- /dev/null +++ b/parser/src/parser/tests/buses.rs @@ -0,0 +1,95 @@ +use miden_diagnostics::SourceSpan; + +use crate::ast::*; + +use super::ParseTest; + +#[test] +fn buses() { + let source = " + mod test + + buses { + unit p, + mult q, + }"; + + let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + expected.buses.insert( + ident!(p), + Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Unit), + ); + expected.buses.insert( + ident!(q), + Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Mult), + ); + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn empty_buses() { + let source = " + mod test + + buses{}"; + + let expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + ParseTest::new().expect_module_ast(source, expected); +} + +#[test] +fn boundary_constraints_buses() { + let _source = " + mod test + + buses { + unit p, + mult q, + } + + boundary_constraints { + p.first = null; + q.last = null; + }"; + + /*let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + expected.buses.insert( + ident!(p), + Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Unit), + ); + expected.buses.insert( + ident!(q), + Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Mult), + ); + ParseTest::new().expect_module_ast(source, expected);*/ +} + +#[test] +fn integrity_constraints_buses() { + let _source = " + mod test + + buses { + unit p, + mult q, + } + + integrity_constraints { + p.add(1) when 1; + p.rem(1) when 1; + q.add(1, 2) when 1; + q.add(1, 2) when 1; + q.rem(1, 2) for 2; + }"; + + /*let mut expected = Module::new(ModuleType::Library, SourceSpan::UNKNOWN, ident!(test)); + expected.buses.insert( + ident!(p), + Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Unit), + ); + expected.buses.insert( + ident!(q), + Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Mult), + ); + ParseTest::new().expect_module_ast(source, expected);*/ +} diff --git a/parser/src/parser/tests/integrity_constraints.rs b/parser/src/parser/tests/integrity_constraints.rs index 85b754e35..344ce577a 100644 --- a/parser/src/parser/tests/integrity_constraints.rs +++ b/parser/src/parser/tests/integrity_constraints.rs @@ -50,6 +50,98 @@ fn integrity_constraints() { ParseTest::new().expect_module_ast(source, expected); } +#[test] +fn integrity_constraints_with_buses() { + let source = " + def test + + trace_columns { + main: [clk], + } + + buses { + unit p, + mult q, + } + + public_inputs { + inputs: [2], + } + + boundary_constraints { + enf p.first = null; + enf q.last = null; + } + + integrity_constraints { + p.add(1) when 1; + p.rem(1) when 1; + q.add(1, 2) when 1; + q.add(1, 2) when 1; + q.rem(1, 2) for 2; + }"; + + let mut expected = Module::new(ModuleType::Root, SourceSpan::UNKNOWN, ident!(test)); + expected + .trace_columns + .push(trace_segment!(0, "$main", [(clk, 1)])); + expected.public_inputs.insert( + ident!(inputs), + PublicInput::new(SourceSpan::UNKNOWN, ident!(inputs), 2), + ); + expected.boundary_constraints = Some(Span::new( + SourceSpan::UNKNOWN, + vec![ + enforce!(eq!(bounded_access!(p, Boundary::First), null!())), + enforce!(eq!(bounded_access!(q, Boundary::Last), null!())), + ], + )); + expected.buses.insert( + ident!(p), + Bus::new(SourceSpan::UNKNOWN, ident!(p), BusType::Unit), + ); + expected.buses.insert( + ident!(q), + Bus::new(SourceSpan::UNKNOWN, ident!(q), BusType::Mult), + ); + + let mut bus_enforces = Vec::new(); + + // p.add(1) when 1; + bus_enforces.push(bus_enforce!(lc!( + (("%0", range!(0..1))) => + bus_add!(p, vec![expr!(int!(1))]), + when int!(1)))); + + //p.rem(1) when 1; + bus_enforces.push(bus_enforce!(lc!( + (("%1", range!(0..1))) => + bus_rem!(p, vec![expr!(int!(1))]), + when int!(1)))); + + //q.add(1, 2) when 1; + bus_enforces.push(bus_enforce!(lc!( + (("%2", range!(0..1))) => + bus_add!(q, vec![expr!(int!(1)), expr!(int!(2))]), + when int!(1)))); + + //q.add(1, 2) when 1; + bus_enforces.push(bus_enforce!(lc!( + (("%3", range!(0..1))) => + bus_add!(q, vec![expr!(int!(1)), expr!(int!(2))]), + when int!(1)))); + + //q.rem(1, 2) for 2; + bus_enforces.push(bus_enforce!(lc!( + (("%4", range!(0..1))) => + bus_rem!(q, vec![expr!(int!(1)), expr!(int!(2))]), + for int!(2)))); + + expected.integrity_constraints = Some(Span::new(SourceSpan::UNKNOWN, bus_enforces)); + + ParseTest::new().expect_module_ast(source, expected); +} + #[test] fn err_integrity_constraints_invalid() { let source = " diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index 1011e2084..0e91177ab 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -381,6 +381,15 @@ macro_rules! int { }; } +macro_rules! null { + () => { + ScalarExpr::Null(miden_diagnostics::Span::new( + miden_diagnostics::SourceSpan::UNKNOWN, + (), + )) + }; +} + macro_rules! call { ($callee:ident ($($param:expr),+)) => { ScalarExpr::Call(Call::new(miden_diagnostics::SourceSpan::UNKNOWN, ident!($callee), vec![$($param),+])) @@ -482,6 +491,12 @@ macro_rules! enforce_all { }; } +macro_rules! bus_enforce { + ($expr:expr) => { + Statement::BusEnforce($expr) + }; +} + macro_rules! lc { (($(($binding:ident, $iterable:expr)),+) => $body:expr) => {{ let context = vec![ @@ -518,6 +533,25 @@ macro_rules! lc { ]; ListComprehension::new(miden_diagnostics::SourceSpan::UNKNOWN, $body, context, Some($selector)) }}; + + + (($(($binding:ident, $iterable:expr)),*) => $body:expr, for $selector:expr) => {{ + let context = vec![ + $( + (ident!($binding), $iterable) + ),+ + ]; + ListComprehension::new(miden_diagnostics::SourceSpan::UNKNOWN, $body, context, Some($selector)) + }}; + + (($(($binding:literal, $iterable:expr)),*) => $body:expr, for $selector:expr) => {{ + let context = vec![ + $( + (ident!($binding), $iterable) + ),+ + ]; + ListComprehension::new(miden_diagnostics::SourceSpan::UNKNOWN, $body, context, Some($selector)) + }}; } macro_rules! range { @@ -562,6 +596,28 @@ macro_rules! eq { }; } +macro_rules! bus_add { + ($bus:ident, $expr:expr) => { + ScalarExpr::BusOperation(BusOperation::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($bus), + BusOperator::Add, + $expr, + )) + }; +} + +macro_rules! bus_rem { + ($bus:ident, $rhs:expr) => { + ScalarExpr::BusOperation(BusOperation::new( + miden_diagnostics::SourceSpan::UNKNOWN, + ident!($bus), + BusOperator::Rem, + $rhs, + )) + }; +} + macro_rules! add { ($lhs:expr, $rhs:expr) => { ScalarExpr::Binary(BinaryExpr::new( @@ -627,6 +683,7 @@ macro_rules! import { mod arithmetic_ops; mod boundary_constraints; +mod buses; mod calls; mod constant_propagation; mod constants; diff --git a/parser/src/sema/binding_type.rs b/parser/src/sema/binding_type.rs index db3893780..d763d9c32 100644 --- a/parser/src/sema/binding_type.rs +++ b/parser/src/sema/binding_type.rs @@ -1,4 +1,6 @@ -use crate::ast::{AccessType, FunctionType, InvalidAccessError, RandBinding, TraceBinding, Type}; +use crate::ast::{ + AccessType, BusType, FunctionType, InvalidAccessError, RandBinding, TraceBinding, Type, +}; use std::fmt; /// This type provides type and contextual information about a binding, @@ -19,6 +21,8 @@ pub enum BindingType { /// /// The result type is None if the function is an evaluator Function(FunctionType), + /// A binding to a bus definition + Bus(BusType), /// A function parameter corresponding to trace columns TraceParam(TraceBinding), /// A direct reference to one or more contiguous trace columns @@ -43,6 +47,7 @@ impl BindingType { Self::Local(ty) | Self::Constant(ty) | Self::PublicInput(ty) => Some(*ty), Self::PeriodicColumn(_) => Some(Type::Felt), Self::Function(ty) => ty.result(), + Self::Bus(_) => Some(Type::Felt), } } @@ -218,6 +223,8 @@ impl BindingType { _ => Err(InvalidAccessError::IndexIntoScalar), }, Self::Function(_) => Err(InvalidAccessError::InvalidBinding), + // TODO: FIXME + Self::Bus(bus) => Ok(Self::Bus(bus.clone())), } } } @@ -233,6 +240,7 @@ impl fmt::Display for BindingType { Self::RandomValue(_) => f.write_str("random value(s)"), Self::PublicInput(_) => f.write_str("public input(s)"), Self::PeriodicColumn(_) => f.write_str("periodic column(s)"), + Self::Bus(_) => f.write_str("bus"), } } } diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index f0079b5bb..3586541f6 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -280,6 +280,20 @@ impl VisitMut for SemanticAnalysis<'_> { ); } + // Next, we add all buses to the set of local bindings. + // Buses are in their own namespace, but may conflict with imported items + for (bus_name, bus) in module.buses.iter() { + let namespaced_name = NamespacedIdentifier::Binding(*bus_name); + if let Some((prev, _)) = self.imported.get_key_value(&namespaced_name) { + self.declaration_import_conflict(namespaced_name.span(), prev.span())?; + } + assert_eq!( + self.locals + .insert(namespaced_name, BindingType::Bus(bus.bus_type.clone())), + None + ); + } + // Next, we add any periodic columns to the set of local bindings. // // These _can_ conflict with globally defined names, but are guaranteed not to conflict @@ -318,6 +332,10 @@ impl VisitMut for SemanticAnalysis<'_> { self.visit_mut_function(function)?; } + for bus in module.buses.values_mut() { + self.visit_mut_bus(bus)?; + } + if let Some(boundary_constraints) = module.boundary_constraints.as_mut() { if !boundary_constraints.is_empty() { self.visit_mut_boundary_constraints(boundary_constraints)?; @@ -436,6 +454,10 @@ impl VisitMut for SemanticAnalysis<'_> { ControlFlow::Continue(()) } + fn visit_mut_bus(&mut self, _bus: &mut Bus) -> ControlFlow { + ControlFlow::Continue(()) + } + fn visit_mut_boundary_constraints( &mut self, body: &mut Vec, @@ -506,6 +528,17 @@ impl VisitMut for SemanticAnalysis<'_> { result } + fn visit_mut_bus_enforce( + &mut self, + expr: &mut ListComprehension, + ) -> ControlFlow { + self.in_constraint_comprehension = true; + let result = self.visit_mut_list_comprehension(expr); + self.in_constraint_comprehension = false; + + result + } + fn visit_mut_let(&mut self, expr: &mut Let) -> ControlFlow { // Visit the binding expression first self.visit_mut_expr(&mut expr.value)?; @@ -847,7 +880,7 @@ impl VisitMut for SemanticAnalysis<'_> { .with_message("invalid expression") .with_primary_label( expr.span(), - "references to column boundaries are not permitted here", + "references to column / buses boundaries are not permitted here", ) .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) @@ -978,6 +1011,7 @@ impl VisitMut for SemanticAnalysis<'_> { ResolvableIdentifier::Unresolved(namespaced_id) => { // If locally defined, resolve it to the current module let namespaced_id = *namespaced_id; + if let Some(binding_ty) = self.locals.get(&namespaced_id) { match binding_ty { // This identifier is a local variable, alias to a declaration, or a function parameter @@ -992,7 +1026,8 @@ impl VisitMut for SemanticAnalysis<'_> { // These binding types are module-local declarations BindingType::Constant(_) | BindingType::Function(_) - | BindingType::PeriodicColumn(_) => { + | BindingType::PeriodicColumn(_) + | BindingType::Bus(_) => { *expr = ResolvableIdentifier::Resolved(QualifiedIdentifier::new( current_module, namespaced_id, @@ -1038,10 +1073,10 @@ impl VisitMut for SemanticAnalysis<'_> { NamespacedIdentifier::Binding(_) => { self.diagnostics .diagnostic(Severity::Error) - .with_message("reference to undefined variable") + .with_message("reference to undefined variable / bus") .with_primary_label( namespaced_id.span(), - "this variable is not defined", + "this variable / bus is not defined", ) .emit(); } @@ -1334,7 +1369,7 @@ impl SemanticAnalysis<'_> { // Visit the expression operands self.visit_mut_symbol_access(&mut access.column)?; - // Ensure the referenced symbol was a trace column, and that it produces a scalar value + // Ensure the referenced symbol was a trace column, and that it produces a scalar value, or a bus let (found, segment) = match self.resolvable_binding_type(&access.column.name) { Ok(ty) => match ty.item.access(access.column.access_type.clone()) { @@ -1353,6 +1388,10 @@ impl SemanticAnalysis<'_> { ); } } + Ok(BindingType::Bus(_)) => { + // Buses are valid in boundary constraints + (ty, 0) + } Ok(aty) => { let expected = BindingType::TraceColumn(TraceBinding::new( constraint_span, @@ -1378,37 +1417,116 @@ impl SemanticAnalysis<'_> { } }; - // Validate that the symbol access produces a scalar value - // - // If no type is known, a diagnostic is already emitted, so proceed as if it is valid - if let Some(ty) = access.column.ty.as_ref() { - if !ty.is_scalar() { - // Invalid constraint, only scalar values are allowed - self.type_mismatch( - Some(ty), + match (found.clone().item, expr.rhs.as_mut()) { + // Buses boundaries can be constrained by null, or soon by a public_input + (BindingType::Bus(_), ScalarExpr::Null(_)) => {} + (BindingType::Bus(_), ScalarExpr::SymbolAccess(access)) => { + self.has_type_errors = true; + self.invalid_constraint( access.span(), - &Type::Felt, - found.span(), - constraint_span, - )?; + "public input bus constraints are not yet supported", + ) + .with_secondary_label( + access.name.span(), + "this a reference to a public input", + ) + .emit(); + + // TODO: Update when table public inputs are supported + /*self.visit_mut_resolvable_identifier(&mut access.name)?; + self.visit_mut_access_type(&mut access.access_type)?; + + let resolved_binding_ty = + match self.resolvable_binding_type(&access.name) { + Ok(ty) => ty, + // An unresolved identifier at this point means that it is undefined, + // but we've already raised a diagnostic + // + // There is nothing useful we can do here other than continue traversing the module + // gathering as many undefined variable usages as possible before bailing + Err(_) => return ControlFlow::Continue(()), + }; + + match resolved_binding_ty.item { + BindingType::PublicInput(_) => {} + _ => { + self.has_type_errors = true; + self.invalid_constraint( + access.span(), + "expected a reference to a public input", + ) + .with_secondary_label( + access.name.span(), + "this is not a reference to a public input", + ) + .emit(); + } + }*/ + + // Buses boundaries can be constrained to null, nothing to do } - } + (BindingType::Bus(_), _) => { + // Buses cannot be constrained otherwise + self.has_type_errors = true; + self.invalid_constraint(expr.lhs.span(), "this constrains a bus") + .with_secondary_label( + expr.rhs.span(), + "but this expression is only valid to constrain columns", + ) + .with_note( + //"Only the null value or a public input is valid for constraining buses", + "Only the null value is valid for constraining buses", + ) + .emit(); + } + (_, ScalarExpr::Null(_)) => { + // Only buses can be constrained to null + self.has_type_errors = true; + self.invalid_constraint( + expr.lhs.span(), + "this constrains a column", + ) + .with_secondary_label( + expr.rhs.span(), + "but this expression is only valid to constrain buses", + ) + .with_note("The null value is only valid for defining empty buses") + .emit(); + } + _ => { + // Validate that the symbol access produces a scalar value + // + // If no type is known, a diagnostic is already emitted, so proceed as if it is valid + if let Some(ty) = access.column.ty.as_ref() { + if !ty.is_scalar() { + // Invalid constraint, only scalar values are allowed + self.type_mismatch( + Some(ty), + access.span(), + &Type::Felt, + found.span(), + constraint_span, + )?; + } + } - // Verify that the right-hand expression evaluates to a scalar - // - // The only way this is not the case, is if it is a a symbol access which produces an aggregate - self.visit_mut_scalar_expr(expr.rhs.as_mut())?; - if let ScalarExpr::SymbolAccess(access) = expr.rhs.as_ref() { - // Ensure this access produces a scalar, or if the type is unknown, assume it is valid - // because a diagnostic will have already been emitted - if !access.ty.as_ref().map(|t| t.is_scalar()).unwrap_or(true) { - self.type_mismatch( - access.ty.as_ref(), - access.span(), - &Type::Felt, - access.name.span(), - constraint_span, - )?; + // Verify that the right-hand expression evaluates to a scalar + // + // The only way this is not the case, is if it is a a symbol access which produces an aggregate + self.visit_mut_scalar_expr(expr.rhs.as_mut())?; + if let ScalarExpr::SymbolAccess(access) = expr.rhs.as_ref() { + // Ensure this access produces a scalar, or if the type is unknown, assume it is valid + // because a diagnostic will have already been emitted + if !access.ty.as_ref().map(|t| t.is_scalar()).unwrap_or(true) { + self.type_mismatch( + access.ty.as_ref(), + access.span(), + &Type::Felt, + access.name.span(), + constraint_span, + )?; + } + } } } @@ -1425,13 +1543,19 @@ impl SemanticAnalysis<'_> { ControlFlow::Continue(()) } other => { - self.invalid_constraint(other.span(), "expected this to be a reference to a trace column boundary, e.g. `a.first`") + self.invalid_constraint(other.span(), "expected this to be a reference to a trace column or bus boundary, e.g. `a.first`") .with_note("The given constraint is not a boundary constraint, and only boundary constraints are valid here.") .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) } } } + ScalarExpr::BusOperation(ref mut expr) => { + self.invalid_constraint(expr.span(), "expected an equality expression here") + .with_note("Bus operations are only permitted in integrity constraints") + .emit(); + ControlFlow::Break(SemanticAnalysisError::Invalid) + } ScalarExpr::Call(ref expr) => { self.invalid_constraint(expr.span(), "expected an equality expression here") .with_note( @@ -1461,7 +1585,7 @@ impl SemanticAnalysis<'_> { // However, we do need to validate two things: // // 1. That the constraint produces a scalar value - // 2. That the expression is either an equality, or a call to an evaluator function + // 2. That the expression is either an equality, a call to an evaluator function, or a bus operation // match expr { ScalarExpr::Binary(ref mut expr) if expr.op == BinaryOp::Eq => { @@ -1518,9 +1642,55 @@ impl SemanticAnalysis<'_> { ResolvableIdentifier::Unresolved(_) => ControlFlow::Continue(()), } } + ScalarExpr::BusOperation(ref mut expr) => { + // Visit the call normally, so we can resolve the callee identifier + self.visit_mut_bus_operation(expr)?; + + // Check that the call references an evaluator + // + // If unresolved, we've already raised a diagnostic for the invalid call + match expr.bus { + ResolvableIdentifier::Resolved(bus) => { + match bus.id() { + id @ NamespacedIdentifier::Binding(_) => { + match self.locals.get_key_value(&id) { + // Binding is to a local bus + Some((_, BindingType::Bus(_))) => ControlFlow::Continue(()), + Some((local_name, _)) => { + self.invalid_constraint(id.span(), "bus operations in constraints must be to bus") + .with_secondary_label(local_name.span(), "this function is not an evaluator") + .emit(); + ControlFlow::Break(SemanticAnalysisError::Invalid) + } + None => { + // If the bus was resolved, check it is of a bus + let (import_id, module_id) = self.imported.get_key_value(&id).unwrap(); + let module = self.library.get(module_id).unwrap(); + if !module.buses.contains_key(&id.id()) { + self.invalid_constraint(id.span(), "bus operations in constraints must be to a bus") + .with_secondary_label(import_id.span(), "the identifier imported here is not a bus") + .emit(); + return ControlFlow::Break(SemanticAnalysisError::Invalid); + } + ControlFlow::Continue(()) + } + } + } + id => panic!("invalid bus identifier, expected bus, got binding: {:#?}", id), + } + } + ResolvableIdentifier::Local(id) => { + self.invalid_callee(id.span(), "local variables", "A local binding with this name is in scope, but no such bus is declared in this module. Are you missing an import?") + } + ResolvableIdentifier::Global(id) => { + self.invalid_callee(id.span(), "global declarations", "A global declaration with this name is in scope, but no such such is declared in this module. Are you missing an import?") + } + ResolvableIdentifier::Unresolved(_) => ControlFlow::Continue(()), + } + } expr => { - self.invalid_constraint(expr.span(), "expected either an equality expression, or a call to an evaluator here") - .with_note("Integrity constraints must be expressed as an equality, e.g. `a = 0`, or a call, e.g. `evaluator(a)`") + self.invalid_constraint(expr.span(), "expected either an equality expression, a call to an evaluator, or a bus operation here") + .with_note("Integrity constraints must be expressed as an equality, e.g. `a = 0`, a call, e.g. `evaluator(a)`, or a bus operation, e.g. `p.add(a) when 1`") .emit(); ControlFlow::Break(SemanticAnalysisError::Invalid) } @@ -1679,6 +1849,11 @@ impl SemanticAnalysis<'_> { .emit(); Err(InvalidAccessError::InvalidBinding) } + // TODO BUS: Is it the correct binding type? Or do we want to throw an error? + // It seems that bus operations should be handled like Binary equality (enf a = 1) + // But it seems weird to assign a type to such operations + Expr::BusOperation(ref _expr) => Ok(BindingType::Local(Type::Felt)), + Expr::Null(_) => Ok(BindingType::Local(Type::Felt)), } } diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index d2979c4b4..ea3d11e76 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -75,6 +75,11 @@ impl<'a> ConstantPropagation<'a> { self.visit_mut_function(function)?; } + // Visit all of the buses + for bus in program.buses.values_mut() { + self.visit_mut_bus(bus)?; + } + // Visit all of the constraints self.visit_mut_boundary_constraints(&mut program.boundary_constraints)?; self.visit_mut_integrity_constraints(&mut program.integrity_constraints) @@ -176,7 +181,7 @@ impl VisitMut for ConstantPropagation<'_> { ) -> ControlFlow { match expr { // Expression is already folded - ScalarExpr::Const(_) => ControlFlow::Continue(()), + ScalarExpr::Const(_) | ScalarExpr::Null(_) => ControlFlow::Continue(()), // Need to check if this access is to a constant value, and transform to a constant if so ScalarExpr::SymbolAccess(sym) => { let constant_value = match sym.name { @@ -275,12 +280,14 @@ impl VisitMut for ConstantPropagation<'_> { } Statement::Enforce(_) | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) => unreachable!(), + | Statement::EnforceAll(_) + | Statement::BusEnforce(_) => unreachable!(), }, Err(err) => return ControlFlow::Break(err), } ControlFlow::Continue(()) } + ScalarExpr::BusOperation(ref mut expr) => self.visit_mut_bus_operation(expr), } } @@ -591,12 +598,15 @@ impl VisitMut for ConstantPropagation<'_> { } Statement::Enforce(_) | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) => unreachable!(), + | Statement::EnforceAll(_) + | Statement::BusEnforce(_) => unreachable!(), }, Err(err) => return ControlFlow::Break(err), } ControlFlow::Continue(()) } + Expr::BusOperation(ref mut expr) => self.visit_mut_bus_operation(expr), + Expr::Null(_) => ControlFlow::Continue(()), } } @@ -640,6 +650,11 @@ impl VisitMut for ConstantPropagation<'_> { Statement::Expr(ref mut expr) => { self.visit_mut_expr(expr)?; } + Statement::BusEnforce(ref mut expr) => { + self.in_constraint_comprehension = true; + self.visit_mut_list_comprehension(expr)?; + self.in_constraint_comprehension = false; + } // This statement type is only present in the AST after inlining Statement::EnforceIf(_, _) => unreachable!(), } diff --git a/parser/src/transforms/inlining.rs b/parser/src/transforms/inlining.rs index 51a444183..80b8e48a6 100644 --- a/parser/src/transforms/inlining.rs +++ b/parser/src/transforms/inlining.rs @@ -324,6 +324,13 @@ impl<'a> Inlining<'a> { Expr::Let(let_expr) => Ok(vec![Statement::Let(*let_expr)]), expr => Ok(vec![Statement::Expr(expr)]), }, + Statement::BusEnforce(_) => { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("buses are not implemented for this Pipeline") + .emit(); + Err(SemanticAnalysisError::Invalid) + } } } @@ -358,6 +365,13 @@ impl<'a> Inlining<'a> { Expr::try_from(block.pop().unwrap()).map_err(SemanticAnalysisError::InvalidExpr) } expr @ (Expr::Const(_) | Expr::Range(_) | Expr::SymbolAccess(_)) => Ok(expr), + Expr::BusOperation(_) | Expr::Null(_) => { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("buses are not implemented for this Pipeline") + .emit(); + Err(SemanticAnalysisError::Invalid) + } } } @@ -400,7 +414,8 @@ impl<'a> Inlining<'a> { Statement::Expr(expr) => expr, Statement::Enforce(_) | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) => unreachable!(), + | Statement::EnforceAll(_) + | Statement::BusEnforce(_) => unreachable!(), } } // The operands of a binary expression can contain function calls, so we must ensure @@ -543,7 +558,8 @@ impl<'a> Inlining<'a> { Statement::Let(expr) => Ok(Expr::Let(Box::new(expr))), Statement::Enforce(_) | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) => unreachable!(), + | Statement::EnforceAll(_) + | Statement::BusEnforce(_) => unreachable!(), } } Expr::SymbolAccess(ref access) => { @@ -668,10 +684,18 @@ impl<'a> Inlining<'a> { } Statement::Enforce(_) | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) => unreachable!(), + | Statement::EnforceAll(_) + | Statement::BusEnforce(_) => unreachable!(), } } } + Expr::BusOperation(_) | Expr::Null(_) => { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("buses are not implemented for this Pipeline") + .emit(); + return Err(SemanticAnalysisError::Invalid); + } } Ok(()) } @@ -724,11 +748,19 @@ impl<'a> Inlining<'a> { } Statement::Enforce(_) | Statement::EnforceIf(_, _) - | Statement::EnforceAll(_) => unreachable!(), + | Statement::EnforceAll(_) + | Statement::BusEnforce(_) => unreachable!(), } } Ok(()) } + ScalarExpr::BusOperation(_) | ScalarExpr::Null(_) => { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("buses are not implemented for this Pipeline") + .emit(); + Err(SemanticAnalysisError::Invalid) + } } } @@ -999,7 +1031,12 @@ impl<'a> Inlining<'a> { // Binary expressions are scalar, so cannot be used as iterables, and we don't // (currently) support nested comprehensions, so it is never possible to observe // these expression types here. Calls should have been lifted prior to expansion. - Expr::Call(_) | Expr::Binary(_) | Expr::ListComprehension(_) | Expr::Let(_) => { + Expr::Call(_) + | Expr::Binary(_) + | Expr::ListComprehension(_) + | Expr::Let(_) + | Expr::BusOperation(_) + | Expr::Null(_) => { unreachable!() } }; @@ -1300,7 +1337,10 @@ impl<'a> Inlining<'a> { match function.body.pop().unwrap() { Statement::Expr(expr) => Ok(expr), Statement::Let(expr) => Ok(Expr::Let(Box::new(expr))), - Statement::Enforce(_) | Statement::EnforceIf(_, _) | Statement::EnforceAll(_) => { + Statement::Enforce(_) + | Statement::EnforceIf(_, _) + | Statement::EnforceAll(_) + | Statement::BusEnforce(_) => { panic!("unexpected constraint in function body") } } @@ -1614,6 +1654,9 @@ fn eval_expr_binding_type( eval_expr_binding_type(&lc.iterables[0], bindings, imported) } Expr::Let(ref let_expr) => eval_let_binding_ty(let_expr, bindings, imported), + Expr::BusOperation(_) | Expr::Null(_) => { + unimplemented!("buses are not implemented for this Pipeline") + } } } @@ -1645,7 +1688,10 @@ fn eval_let_binding_ty( let binding_ty = match let_expr.body.last().unwrap() { Statement::Let(ref inner_let) => eval_let_binding_ty(inner_let, bindings, imported)?, Statement::Expr(ref expr) => eval_expr_binding_type(expr, bindings, imported)?, - Statement::Enforce(_) | Statement::EnforceIf(_, _) | Statement::EnforceAll(_) => { + Statement::Enforce(_) + | Statement::EnforceIf(_, _) + | Statement::EnforceAll(_) + | Statement::BusEnforce(_) => { unreachable!() } }; @@ -1745,7 +1791,14 @@ impl RewriteIterableBindingsVisitor<'_> { // These types of expressions will never be observed in this context, as they are // not valid iterable expressions (except calls, but those are lifted prior to rewrite // so that their use in this context is always a symbol access). - Some(Expr::Call(_) | Expr::Binary(_) | Expr::ListComprehension(_) | Expr::Let(_)) => { + Some( + Expr::Call(_) + | Expr::Binary(_) + | Expr::ListComprehension(_) + | Expr::Let(_) + | Expr::BusOperation(_) + | Expr::Null(_), + ) => { unreachable!() } None => None, @@ -1803,6 +1856,9 @@ impl VisitMut for RewriteIterableBindingsVisitor<'_> { // the case that we encounter a let here, as they can only be introduced in scalar // expression position as a result of inlining/expansion ScalarExpr::Let(_) => unreachable!(), + ScalarExpr::BusOperation(_) | ScalarExpr::Null(_) => { + ControlFlow::Break(SemanticAnalysisError::Invalid) + } } } } @@ -1844,6 +1900,7 @@ impl VisitMut for ApplyConstraintSelector<'_> { } Statement::EnforceAll(_) => unreachable!(), Statement::Expr(_) => ControlFlow::Continue(()), + Statement::BusEnforce(_) => ControlFlow::Break(SemanticAnalysisError::Invalid), } } }