diff --git a/src/libfuncs/bounded_int.rs b/src/libfuncs/bounded_int.rs index d95c02e72..2a2fcd086 100644 --- a/src/libfuncs/bounded_int.rs +++ b/src/libfuncs/bounded_int.rs @@ -200,6 +200,10 @@ fn build_add<'ctx, 'this>( } /// Generate MLIR operations for the `bounded_int_sub` libfunc. +/// +/// Since we want to get C = A - B, we can translate this to +/// Co + Cd = (Ao + Ad) - (Bo + Bd). Where Ao, Bo and Co represent the lower bound +/// of the ranges in the BoundedInt and Ad, Bd and Cd represent the offsets. #[allow(clippy::too_many_arguments)] fn build_sub<'ctx, 'this>( context: &'ctx Context, @@ -213,7 +217,7 @@ fn build_sub<'ctx, 'this>( let lhs_value = entry.arg(0)?; let rhs_value = entry.arg(1)?; - // Extract the ranges for the operands and the result type. + // Extract the ranges for the operands. let lhs_ty = registry.get_type(&info.signature.param_signatures[0].ty)?; let rhs_ty = registry.get_type(&info.signature.param_signatures[1].ty)?; @@ -223,6 +227,7 @@ fn build_sub<'ctx, 'this>( .get_type(&info.signature.branch_signatures[0].vars[0].ty)? .integer_range(registry)?; + // Extract the bit width. let lhs_width = if lhs_ty.is_bounded_int(registry)? { lhs_range.offset_bit_width() } else { @@ -233,31 +238,26 @@ fn build_sub<'ctx, 'this>( } else { rhs_range.zero_based_bit_width() }; + let dst_width = dst_range.offset_bit_width(); - // Calculate the computation range. - let compute_range = Range { - lower: (&lhs_range.lower) - .min(&rhs_range.lower) - .min(&dst_range.lower) - .clone(), - upper: (&lhs_range.upper) - .max(&rhs_range.upper) - .max(&dst_range.upper) - .clone(), - }; - let compute_ty = IntegerType::new(context, compute_range.offset_bit_width()).into(); + // Get the compute type so we can do the subtraction without problems + let compile_time_val = lhs_range.lower.clone() - rhs_range.lower.clone() - dst_range.lower; + let compile_time_val_width = u32::try_from(compile_time_val.bits())?; + + let compute_width = lhs_width.max(rhs_width).max(compile_time_val_width) + 1; + let compute_ty = IntegerType::new(context, compute_width).into(); - // Zero-extend operands into the computation range. native_assert!( - compute_range.offset_bit_width() >= lhs_width, + compute_width >= lhs_width, "the lhs_range bit_width must be less or equal than the compute_range" ); native_assert!( - compute_range.offset_bit_width() >= rhs_width, + compute_width >= rhs_width, "the rhs_range bit_width must be less or equal than the compute_range" ); - let lhs_value = if compute_range.offset_bit_width() > lhs_width { + // Get the operands on the same number of bits so we can operate with them + let lhs_value = if compute_width > lhs_width { if lhs_range.lower.sign() != Sign::Minus || lhs_ty.is_bounded_int(registry)? { entry.extui(lhs_value, compute_ty, location)? } else { @@ -266,7 +266,7 @@ fn build_sub<'ctx, 'this>( } else { lhs_value }; - let rhs_value = if compute_range.offset_bit_width() > rhs_width { + let rhs_value = if compute_width > rhs_width { if rhs_range.lower.sign() != Sign::Minus || rhs_ty.is_bounded_int(registry)? { entry.extui(rhs_value, compute_ty, location)? } else { @@ -276,47 +276,23 @@ fn build_sub<'ctx, 'this>( rhs_value }; - // Offset the operands so that they are compatible. - let lhs_offset = if lhs_ty.is_bounded_int(registry)? { - &lhs_range.lower - &compute_range.lower - } else { - lhs_range.lower - }; - let lhs_value = if lhs_offset != BigInt::ZERO { - let lhs_offset = entry.const_int_from_type(context, location, lhs_offset, compute_ty)?; - entry.addi(lhs_value, lhs_offset, location)? - } else { - lhs_value - }; - - let rhs_offset = if rhs_ty.is_bounded_int(registry)? { - &rhs_range.lower - &compute_range.lower - } else { - rhs_range.lower - }; - let rhs_value = if rhs_offset != BigInt::ZERO { - let rhs_offset = entry.const_int_from_type(context, location, rhs_offset, compute_ty)?; - entry.addi(rhs_value, rhs_offset, location)? - } else { - rhs_value - }; - - // Compute the operation. - let res_value = entry.append_op_result(arith::subi(lhs_value, rhs_value, location))?; - - // Offset and truncate the result to the output type. - let res_offset = dst_range.lower.clone(); - let res_value = if res_offset != BigInt::ZERO { - let res_offset = entry.const_int_from_type(context, location, res_offset, compute_ty)?; - entry.append_op_result(arith::subi(res_value, res_offset, location))? - } else { - res_value - }; - - let res_value = if dst_range.offset_bit_width() < compute_range.offset_bit_width() { + let compile_time_val = + entry.const_int_from_type(context, location, compile_time_val, compute_ty)?; + // First we do -> Ad - Bd = intermediate_res + let res_value = entry.subi(lhs_value, rhs_value, location)?; + // Then we do -> intermediate_res + (Ao - Bo - Co) + let res_value = entry.addi(res_value, compile_time_val, location)?; + // Get the result value on the desired range + let res_value = if compute_width > dst_width { entry.trunci( res_value, - IntegerType::new(context, dst_range.offset_bit_width()).into(), + IntegerType::new(context, dst_width).into(), + location, + )? + } else if compute_width < dst_width { + entry.extui( + res_value, + IntegerType::new(context, dst_width).into(), location, )? } else { @@ -871,8 +847,12 @@ mod test { use num_bigint::BigInt; use crate::{ - context::NativeContext, execution_result::ExecutionResult, executor::JitNativeExecutor, - load_cairo, utils::testing::run_program, OptLevel, Value, + context::NativeContext, + execution_result::ExecutionResult, + executor::JitNativeExecutor, + jit_enum, jit_struct, load_cairo, + utils::testing::{run_program, run_program_assert_output}, + OptLevel, Value, }; #[test] @@ -1007,6 +987,174 @@ mod test { assert_eq!(value, Felt252::from(0)); } + #[test] + fn test_sub() { + let cairo = load_cairo! { + #[feature("bounded-int-utils")] + use core::internal::bounded_int::{BoundedInt, sub, SubHelper}; + + impl SubHelper1 of SubHelper, BoundedInt<1, 5>> { + type Result = BoundedInt<-4, 0>; + } + + fn run_test_1( + a: felt252, + b: felt252, + ) -> BoundedInt<-4, 0> { + let a: BoundedInt<1, 1> = a.try_into().unwrap(); + let b: BoundedInt<1, 5> = b.try_into().unwrap(); + return sub(a, b); + } + + impl SubHelper2 of SubHelper, BoundedInt<1, 1>> { + type Result = BoundedInt<0, 0>; + } + + fn run_test_2( + a: felt252, + b: felt252, + ) -> BoundedInt<0, 0> { + let a: BoundedInt<1, 1> = a.try_into().unwrap(); + let b: BoundedInt<1, 1> = b.try_into().unwrap(); + return sub(a, b); + } + + impl SubHelper3 of SubHelper, BoundedInt<-3, -3>> { + type Result = BoundedInt<0, 0>; + } + + fn run_test_3( + a: felt252, + b: felt252, + ) -> BoundedInt<0, 0> { + let a: BoundedInt<-3, -3> = a.try_into().unwrap(); + let b: BoundedInt<-3, -3> = b.try_into().unwrap(); + return sub(a, b); + } + + impl SubHelper4 of SubHelper, BoundedInt<1, 3>> { + type Result = BoundedInt<-9, -4>; + } + + fn run_test_4( + a: felt252, + b: felt252, + ) -> BoundedInt<-9, -4> { + let a: BoundedInt<-6, -3> = a.try_into().unwrap(); + let b: BoundedInt<1, 3> = b.try_into().unwrap(); + return sub(a, b); + } + + impl SubHelper5 of SubHelper, BoundedInt<-20, -10>> { + type Result = BoundedInt<4, 18>; + } + + fn run_test_5( + a: felt252, + b: felt252, + ) -> BoundedInt<4, 18> { + let a: BoundedInt<-6, -2> = a.try_into().unwrap(); + let b: BoundedInt<-20, -10> = b.try_into().unwrap(); + return sub(a, b); + } + }; + + run_program_assert_output( + &cairo, + "run_test_1", + &[ + Value::Felt252(Felt252::from(1)), + Value::Felt252(Felt252::from(5)), + ], + jit_enum!( + 0, + jit_struct!(Value::BoundedInt { + value: Felt252::from(-4), + range: Range { + lower: BigInt::from(-4), + upper: BigInt::from(1), + } + }) + ), + ); + + run_program_assert_output( + &cairo, + "run_test_2", + &[ + Value::Felt252(Felt252::from(1)), + Value::Felt252(Felt252::from(1)), + ], + jit_enum!( + 0, + jit_struct!(Value::BoundedInt { + value: Felt252::from(0), + range: Range { + lower: BigInt::from(0), + upper: BigInt::from(1), + } + }) + ), + ); + + run_program_assert_output( + &cairo, + "run_test_3", + &[ + Value::Felt252(Felt252::from(-3)), + Value::Felt252(Felt252::from(-3)), + ], + jit_enum!( + 0, + jit_struct!(Value::BoundedInt { + value: Felt252::from(0), + range: Range { + lower: BigInt::from(0), + upper: BigInt::from(1), + } + }) + ), + ); + + run_program_assert_output( + &cairo, + "run_test_4", + &[ + Value::Felt252(Felt252::from(-6)), + Value::Felt252(Felt252::from(3)), + ], + jit_enum!( + 0, + jit_struct!(Value::BoundedInt { + value: Felt252::from(-9), + range: Range { + lower: BigInt::from(-9), + upper: BigInt::from(-3), + } + }) + ), + ); + + run_program_assert_output( + &cairo, + "run_test_5", + &[ + Value::Felt252(Felt252::from(-2)), + Value::Felt252(Felt252::from(-20)), + ], + jit_enum!( + 0, + jit_struct!(Value::BoundedInt { + value: Felt252::from(18), + range: Range { + lower: BigInt::from(4), + upper: BigInt::from(19), + } + }) + ), + ) + } + fn assert_bool_output(result: Value, expected_tag: usize) { if let Value::Enum { tag, value, .. } = result { assert_eq!(tag, 0);