Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
253 changes: 196 additions & 57 deletions src/libfuncs/bounded_int.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,31 +233,25 @@ fn build_sub<'ctx, 'this>(
} else {
rhs_range.zero_based_bit_width()
};
let dst_width = dst_range.offset_bit_width();

// Calculate the computation range.
let compute_range = Range {
lower: (&lhs_range.lower)
.min(&rhs_range.lower)
.min(&dst_range.lower)
.clone(),
upper: (&lhs_range.upper)
.max(&rhs_range.upper)
.max(&dst_range.upper)
.clone(),
};
let compute_ty = IntegerType::new(context, compute_range.offset_bit_width()).into();
let compile_time_val = lhs_range.lower.clone() - rhs_range.lower.clone() - dst_range.lower;
let compile_time_val_width = u32::try_from(compile_time_val.bits())?;

let compute_width = lhs_width.max(rhs_width).max(compile_time_val_width) + 1; // TODO: Check if the +1 is necessary
let compute_ty = IntegerType::new(context, compute_width).into();

// Zero-extend operands into the computation range.
native_assert!(
compute_range.offset_bit_width() >= lhs_width,
compute_width >= lhs_width,
"the lhs_range bit_width must be less or equal than the compute_range"
);
native_assert!(
compute_range.offset_bit_width() >= rhs_width,
compute_width >= rhs_width,
"the rhs_range bit_width must be less or equal than the compute_range"
);

let lhs_value = if compute_range.offset_bit_width() > lhs_width {
let lhs_value = if compute_width > lhs_width {
if lhs_range.lower.sign() != Sign::Minus || lhs_ty.is_bounded_int(registry)? {
entry.extui(lhs_value, compute_ty, location)?
} else {
Expand All @@ -266,7 +260,7 @@ fn build_sub<'ctx, 'this>(
} else {
lhs_value
};
let rhs_value = if compute_range.offset_bit_width() > rhs_width {
let rhs_value = if compute_width > rhs_width {
if rhs_range.lower.sign() != Sign::Minus || rhs_ty.is_bounded_int(registry)? {
entry.extui(rhs_value, compute_ty, location)?
} else {
Expand All @@ -276,47 +270,20 @@ fn build_sub<'ctx, 'this>(
rhs_value
};

// Offset the operands so that they are compatible.
let lhs_offset = if lhs_ty.is_bounded_int(registry)? {
&lhs_range.lower - &compute_range.lower
} else {
lhs_range.lower
};
let lhs_value = if lhs_offset != BigInt::ZERO {
let lhs_offset = entry.const_int_from_type(context, location, lhs_offset, compute_ty)?;
entry.addi(lhs_value, lhs_offset, location)?
} else {
lhs_value
};

let rhs_offset = if rhs_ty.is_bounded_int(registry)? {
&rhs_range.lower - &compute_range.lower
} else {
rhs_range.lower
};
let rhs_value = if rhs_offset != BigInt::ZERO {
let rhs_offset = entry.const_int_from_type(context, location, rhs_offset, compute_ty)?;
entry.addi(rhs_value, rhs_offset, location)?
} else {
rhs_value
};

// Compute the operation.
let res_value = entry.append_op_result(arith::subi(lhs_value, rhs_value, location))?;

// Offset and truncate the result to the output type.
let res_offset = dst_range.lower.clone();
let res_value = if res_offset != BigInt::ZERO {
let res_offset = entry.const_int_from_type(context, location, res_offset, compute_ty)?;
entry.append_op_result(arith::subi(res_value, res_offset, location))?
} else {
res_value
};

let res_value = if dst_range.offset_bit_width() < compute_range.offset_bit_width() {
let compile_time_val =
entry.const_int_from_type(context, location, compile_time_val, compute_ty)?;
let res_value = entry.subi(lhs_value, rhs_value, location)?;
let res_value = entry.addi(res_value, compile_time_val, location)?;
let res_value = if compute_width > dst_width {
entry.trunci(
res_value,
IntegerType::new(context, dst_range.offset_bit_width()).into(),
IntegerType::new(context, dst_width).into(),
location,
)?
} else if compute_width < dst_width {
entry.extui(
res_value,
IntegerType::new(context, dst_width).into(),
location,
)?
} else {
Expand Down Expand Up @@ -871,8 +838,12 @@ mod test {
use num_bigint::BigInt;

use crate::{
context::NativeContext, execution_result::ExecutionResult, executor::JitNativeExecutor,
load_cairo, utils::testing::run_program, OptLevel, Value,
context::NativeContext,
execution_result::ExecutionResult,
executor::JitNativeExecutor,
jit_enum, jit_struct, load_cairo,
utils::testing::{run_program, run_program_assert_output},
OptLevel, Value,
};

#[test]
Expand Down Expand Up @@ -1007,6 +978,174 @@ mod test {
assert_eq!(value, Felt252::from(0));
}

#[test]
fn test_sub() {
let cairo = load_cairo! {
#[feature("bounded-int-utils")]
use core::internal::bounded_int::{BoundedInt, sub, SubHelper};

impl SubHelper1 of SubHelper<BoundedInt<1, 1>, BoundedInt<1, 5>> {
type Result = BoundedInt<-4, 0>;
}

fn run_test_1(
a: felt252,
b: felt252,
) -> BoundedInt<-4, 0> {
let a: BoundedInt<1, 1> = a.try_into().unwrap();
let b: BoundedInt<1, 5> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelper2 of SubHelper<BoundedInt<1, 1>, BoundedInt<1, 1>> {
type Result = BoundedInt<0, 0>;
}

fn run_test_2(
a: felt252,
b: felt252,
) -> BoundedInt<0, 0> {
let a: BoundedInt<1, 1> = a.try_into().unwrap();
let b: BoundedInt<1, 1> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelper3 of SubHelper<BoundedInt<-3, -3>, BoundedInt<-3, -3>> {
type Result = BoundedInt<0, 0>;
}

fn run_test_3(
a: felt252,
b: felt252,
) -> BoundedInt<0, 0> {
let a: BoundedInt<-3, -3> = a.try_into().unwrap();
let b: BoundedInt<-3, -3> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelper4 of SubHelper<BoundedInt<-6, -3>, BoundedInt<1, 3>> {
type Result = BoundedInt<-9, -4>;
}

fn run_test_4(
a: felt252,
b: felt252,
) -> BoundedInt<-9, -4> {
let a: BoundedInt<-6, -3> = a.try_into().unwrap();
let b: BoundedInt<1, 3> = b.try_into().unwrap();
return sub(a, b);
}

impl SubHelper5 of SubHelper<BoundedInt<-6, -2>, BoundedInt<-20, -10>> {
type Result = BoundedInt<4, 18>;
}

fn run_test_5(
a: felt252,
b: felt252,
) -> BoundedInt<4, 18> {
let a: BoundedInt<-6, -2> = a.try_into().unwrap();
let b: BoundedInt<-20, -10> = b.try_into().unwrap();
return sub(a, b);
}
};

run_program_assert_output(
&cairo,
"run_test_1",
&[
Value::Felt252(Felt252::from(1)),
Value::Felt252(Felt252::from(5)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(-4),
range: Range {
lower: BigInt::from(-4),
upper: BigInt::from(1),
}
})
),
);

run_program_assert_output(
&cairo,
"run_test_2",
&[
Value::Felt252(Felt252::from(1)),
Value::Felt252(Felt252::from(1)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(0),
range: Range {
lower: BigInt::from(0),
upper: BigInt::from(1),
}
})
),
);

run_program_assert_output(
&cairo,
"run_test_3",
&[
Value::Felt252(Felt252::from(-3)),
Value::Felt252(Felt252::from(-3)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(0),
range: Range {
lower: BigInt::from(0),
upper: BigInt::from(1),
}
})
),
);

run_program_assert_output(
&cairo,
"run_test_4",
&[
Value::Felt252(Felt252::from(-6)),
Value::Felt252(Felt252::from(3)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(-9),
range: Range {
lower: BigInt::from(-9),
upper: BigInt::from(-3),
}
})
),
);

run_program_assert_output(
&cairo,
"run_test_5",
&[
Value::Felt252(Felt252::from(-2)),
Value::Felt252(Felt252::from(-20)),
],
jit_enum!(
0,
jit_struct!(Value::BoundedInt {
value: Felt252::from(18),
range: Range {
lower: BigInt::from(4),
upper: BigInt::from(19),
}
})
),
)
}

fn assert_bool_output(result: Value, expected_tag: usize) {
if let Value::Enum { tag, value, .. } = result {
assert_eq!(tag, 0);
Expand Down
Loading