diff --git a/crates/air/benches/constraint_bench.rs b/crates/air/benches/constraint_bench.rs index f122f92..46ddf63 100644 --- a/crates/air/benches/constraint_bench.rs +++ b/crates/air/benches/constraint_bench.rs @@ -4,108 +4,108 @@ //! //! This benchmark compares bit-based vs lookup-based bitwise constraints. -use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use zp1_air::rv32im::{ConstraintEvaluator, CpuTraceRow}; use zp1_primitives::M31; -use zp1_air::rv32im::{CpuTraceRow, ConstraintEvaluator}; /// Create a test row for bitwise AND operation. fn create_and_row() -> CpuTraceRow { let mut row = CpuTraceRow::default(); - + // AND: 0x12345678 & 0x0F0F0F0F = 0x02040608 row.is_and = M31::ONE; - + // rs1 = 0x12345678 row.rs1_val_lo = M31::new(0x5678); row.rs1_val_hi = M31::new(0x1234); - + // rs2 = 0x0F0F0F0F row.rs2_val_lo = M31::new(0x0F0F); row.rs2_val_hi = M31::new(0x0F0F); - + // Result = 0x02040608 row.rd_val_lo = M31::new(0x0608); row.rd_val_hi = M31::new(0x0204); - + // Bit decomposition for bit-based constraints let rs1: u32 = 0x12345678; let rs2: u32 = 0x0F0F0F0F; let result = rs1 & rs2; - + for i in 0..32 { row.rs1_bits[i] = M31::new(((rs1 >> i) & 1) as u32); row.rs2_bits[i] = M31::new(((rs2 >> i) & 1) as u32); row.and_bits[i] = M31::new(((result >> i) & 1) as u32); } - + // Byte decomposition for lookup-based constraints row.rs1_bytes[0] = M31::new(0x78); row.rs1_bytes[1] = M31::new(0x56); row.rs1_bytes[2] = M31::new(0x34); row.rs1_bytes[3] = M31::new(0x12); - + row.rs2_bytes[0] = M31::new(0x0F); row.rs2_bytes[1] = M31::new(0x0F); row.rs2_bytes[2] = M31::new(0x0F); row.rs2_bytes[3] = M31::new(0x0F); - + row.and_result_bytes[0] = M31::new(0x08); row.and_result_bytes[1] = M31::new(0x06); row.and_result_bytes[2] = M31::new(0x04); row.and_result_bytes[3] = M31::new(0x02); - + row } /// Create a test row for bitwise XOR operation. fn create_xor_row() -> CpuTraceRow { let mut row = CpuTraceRow::default(); - + // XOR: 0xAAAAAAAA ^ 0x55555555 = 0xFFFFFFFF row.is_xor = M31::ONE; - + row.rs1_val_lo = M31::new(0xAAAA); row.rs1_val_hi = M31::new(0xAAAA); - + row.rs2_val_lo = M31::new(0x5555); row.rs2_val_hi = M31::new(0x5555); - + row.rd_val_lo = M31::new(0xFFFF); row.rd_val_hi = M31::new(0xFFFF); - + // Bit decomposition let rs1: u32 = 0xAAAAAAAA; let rs2: u32 = 0x55555555; let result = rs1 ^ rs2; - + for i in 0..32 { row.rs1_bits[i] = M31::new(((rs1 >> i) & 1) as u32); row.rs2_bits[i] = M31::new(((rs2 >> i) & 1) as u32); row.xor_bits[i] = M31::new(((result >> i) & 1) as u32); } - + // Byte decomposition row.rs1_bytes[0] = M31::new(0xAA); row.rs1_bytes[1] = M31::new(0xAA); row.rs1_bytes[2] = M31::new(0xAA); row.rs1_bytes[3] = M31::new(0xAA); - + row.rs2_bytes[0] = M31::new(0x55); row.rs2_bytes[1] = M31::new(0x55); row.rs2_bytes[2] = M31::new(0x55); row.rs2_bytes[3] = M31::new(0x55); - + row.xor_result_bytes[0] = M31::new(0xFF); row.xor_result_bytes[1] = M31::new(0xFF); row.xor_result_bytes[2] = M31::new(0xFF); row.xor_result_bytes[3] = M31::new(0xFF); - + row } fn bench_and_bit_based(c: &mut Criterion) { let row = create_and_row(); - + c.bench_function("AND_bit_based", |b| { b.iter(|| ConstraintEvaluator::and_constraint(black_box(&row))) }); @@ -113,7 +113,7 @@ fn bench_and_bit_based(c: &mut Criterion) { fn bench_and_lookup_based(c: &mut Criterion) { let row = create_and_row(); - + c.bench_function("AND_lookup_based", |b| { b.iter(|| ConstraintEvaluator::and_constraint_lookup(black_box(&row))) }); @@ -121,7 +121,7 @@ fn bench_and_lookup_based(c: &mut Criterion) { fn bench_xor_bit_based(c: &mut Criterion) { let row = create_xor_row(); - + c.bench_function("XOR_bit_based", |b| { b.iter(|| ConstraintEvaluator::xor_constraint(black_box(&row))) }); @@ -129,7 +129,7 @@ fn bench_xor_bit_based(c: &mut Criterion) { fn bench_xor_lookup_based(c: &mut Criterion) { let row = create_xor_row(); - + c.bench_function("XOR_lookup_based", |b| { b.iter(|| ConstraintEvaluator::xor_constraint_lookup(black_box(&row))) }); @@ -137,50 +137,46 @@ fn bench_xor_lookup_based(c: &mut Criterion) { fn bench_bitwise_comparison(c: &mut Criterion) { let mut group = c.benchmark_group("Bitwise_Constraints"); - + let and_row = create_and_row(); let xor_row = create_xor_row(); - + group.bench_function("AND/bit_based", |b| { b.iter(|| ConstraintEvaluator::and_constraint(black_box(&and_row))) }); - + group.bench_function("AND/lookup_based", |b| { b.iter(|| ConstraintEvaluator::and_constraint_lookup(black_box(&and_row))) }); - + group.bench_function("XOR/bit_based", |b| { b.iter(|| ConstraintEvaluator::xor_constraint(black_box(&xor_row))) }); - + group.bench_function("XOR/lookup_based", |b| { b.iter(|| ConstraintEvaluator::xor_constraint_lookup(black_box(&xor_row))) }); - + group.finish(); } /// Benchmark evaluating many rows (simulating trace evaluation) fn bench_batch_evaluation(c: &mut Criterion) { let mut group = c.benchmark_group("Batch_Evaluation"); - + for num_rows in [100, 1000, 10000] { let rows: Vec = (0..num_rows).map(|_| create_and_row()).collect(); - - group.bench_with_input( - BenchmarkId::new("bit_based", num_rows), - &rows, - |b, rows| { - b.iter(|| { - let mut sum = M31::ZERO; - for row in rows.iter() { - sum = sum + ConstraintEvaluator::and_constraint(black_box(row)); - } - sum - }) - }, - ); - + + group.bench_with_input(BenchmarkId::new("bit_based", num_rows), &rows, |b, rows| { + b.iter(|| { + let mut sum = M31::ZERO; + for row in rows.iter() { + sum = sum + ConstraintEvaluator::and_constraint(black_box(row)); + } + sum + }) + }); + group.bench_with_input( BenchmarkId::new("lookup_based", num_rows), &rows, @@ -195,7 +191,7 @@ fn bench_batch_evaluation(c: &mut Criterion) { }, ); } - + group.finish(); } diff --git a/crates/air/src/cpu.rs b/crates/air/src/cpu.rs index abfaefb..be0ff5d 100644 --- a/crates/air/src/cpu.rs +++ b/crates/air/src/cpu.rs @@ -10,12 +10,12 @@ pub struct CpuAir; impl CpuAir { /// Evaluate the x0 = 0 constraint. /// When writing to x0 (is_write_x0 selector = 1), rd_val must be 0. - /// + /// /// # Arguments /// * `is_write_x0` - Boolean selector (1 if writing to x0, 0 otherwise) /// * `rd_val_lo` - Lower 16-bit limb of value being written /// * `rd_val_hi` - Upper 16-bit limb of value being written - /// + /// /// # Returns /// Sum of two constraints (one per limb): is_write_x0 * rd_val_lo + is_write_x0 * rd_val_hi #[inline] @@ -95,9 +95,7 @@ impl CpuAir { ) -> (M31, M31) { // Same as ADD but with immediate Self::add_constraint( - is_addi, rd_val_lo, rd_val_hi, - rs1_val_lo, rs1_val_hi, - imm_lo, imm_hi, carry + is_addi, rd_val_lo, rd_val_hi, rs1_val_lo, rs1_val_hi, imm_lo, imm_hi, carry, ) } @@ -161,33 +159,21 @@ impl CpuAir { /// Evaluate SLLI (Shift Left Logical Immediate) constraint. /// rd_val = rs1_val << shamt - pub fn slli_constraint( - bits_rs1: &[M31; 32], - bits_result: &[M31; 32], - shamt: M31, - ) -> Vec { + pub fn slli_constraint(bits_rs1: &[M31; 32], bits_result: &[M31; 32], shamt: M31) -> Vec { // Same as SLL but with immediate shift amount Self::shift_left_logical_constraints(bits_rs1, bits_result, shamt) } /// Evaluate SRLI (Shift Right Logical Immediate) constraint. /// rd_val = rs1_val >> shamt - pub fn srli_constraint( - bits_rs1: &[M31; 32], - bits_result: &[M31; 32], - shamt: M31, - ) -> Vec { + pub fn srli_constraint(bits_rs1: &[M31; 32], bits_result: &[M31; 32], shamt: M31) -> Vec { // Same as SRL but with immediate shift amount Self::shift_right_logical_constraints(bits_rs1, bits_result, shamt) } /// Evaluate SRAI (Shift Right Arithmetic Immediate) constraint. /// rd_val = rs1_val >> shamt (sign-extended) - pub fn srai_constraint( - bits_rs1: &[M31; 32], - bits_result: &[M31; 32], - shamt: M31, - ) -> Vec { + pub fn srai_constraint(bits_rs1: &[M31; 32], bits_result: &[M31; 32], shamt: M31) -> Vec { // Same as SRA but with immediate shift amount Self::shift_right_arithmetic_constraints(bits_rs1, bits_result, shamt) } @@ -210,39 +196,39 @@ impl CpuAir { bits: &[M31; 32], ) -> Vec { let mut constraints = Vec::with_capacity(34); - + // Constraint: each bit must be 0 or 1 // bit * (bit - 1) = 0 for &bit in bits { constraints.push(bit * (bit - M31::ONE)); } - + // Constraint: bits must reconstruct the value // value = bits[0] + 2*bits[1] + 4*bits[2] + ... + 2^31*bits[31] let mut recon_lo = M31::ZERO; let mut recon_hi = M31::ZERO; let mut power = M31::ONE; - + for i in 0..32 { if i < 16 { recon_lo = recon_lo + bits[i] * power; } else { recon_hi = recon_hi + bits[i] * power; } - + // Update power: multiply by 2 (mod p) power = power + power; - + // After bit 15, reset power for high limb if i == 15 { power = M31::ONE; } } - + // Reconstruction constraints constraints.push(value_lo - recon_lo); constraints.push(value_hi - recon_hi); - + constraints } @@ -299,14 +285,15 @@ impl CpuAir { let two = M31::new(2); for i in 0..32 { // result[i] = a[i] + b[i] - 2*a[i]*b[i] - constraints.push(bits_result[i] - (bits_a[i] + bits_b[i] - two * bits_a[i] * bits_b[i])); + constraints + .push(bits_result[i] - (bits_a[i] + bits_b[i] - two * bits_a[i] * bits_b[i])); } constraints } /// Evaluate SLL (Shift Left Logical) constraint. /// result = value << (shift_amount % 32) - /// + /// /// # Arguments /// * `bits_value` - Bit decomposition of input value /// * `bits_result` - Bit decomposition of result @@ -320,15 +307,15 @@ impl CpuAir { shift_amount: M31, ) -> Vec { let mut constraints = Vec::with_capacity(32); - + // For each possible shift amount (0-31), we need to check: // If shift_amount == k, then result[i] = value[i-k] for i >= k, else 0 // We use selector pattern: is_shift_k * (result[i] - expected[i]) = 0 - + // Convert shift_amount to u32 for computation // Note: In real implementation, shift_amount should be range-checked [0, 31] let shift_val = shift_amount.value() % 32; - + for i in 0..32 { if i < shift_val as usize { // Bits shifted in from right are 0 @@ -339,7 +326,7 @@ impl CpuAir { constraints.push(bits_result[i] - bits_value[src_idx]); } } - + constraints } @@ -355,9 +342,9 @@ impl CpuAir { shift_amount: M31, ) -> Vec { let mut constraints = Vec::with_capacity(32); - + let shift_val = shift_amount.value() % 32; - + for i in 0..32 { let src_idx = i + shift_val as usize; if src_idx >= 32 { @@ -368,7 +355,7 @@ impl CpuAir { constraints.push(bits_result[i] - bits_value[src_idx]); } } - + constraints } @@ -384,10 +371,10 @@ impl CpuAir { shift_amount: M31, ) -> Vec { let mut constraints = Vec::with_capacity(32); - + let shift_val = shift_amount.value() % 32; let sign_bit = bits_value[31]; // MSB is sign bit - + for i in 0..32 { let src_idx = i + shift_val as usize; if src_idx >= 32 { @@ -398,7 +385,7 @@ impl CpuAir { constraints.push(bits_result[i] - bits_value[src_idx]); } } - + constraints } @@ -424,32 +411,32 @@ impl CpuAir { diff_bits: &[M31; 32], ) -> Vec { let mut constraints = Vec::new(); - + // Constraint 1: result must be binary (0 or 1) constraints.push(result * (result - M31::ONE)); - + // Constraint 2: Check sign bits for signed comparison // If sign(a) != sign(b): // result = sign(a) (1 if a is negative, 0 if a is positive) // If sign(a) == sign(b): // result = sign(a - b) - + let sign_a = bits_a[31]; let sign_b = bits_b[31]; let sign_diff = diff_bits[31]; - + // Case 1: Different signs // If a is negative and b is positive: result = 1 // If a is positive and b is negative: result = 0 let diff_signs = sign_a * (M31::ONE - sign_b); // 1 if a<0 and b>=0 - + // Case 2: Same signs - use difference sign let same_signs = M31::ONE - sign_a - sign_b + sign_a * sign_b * M31::new(2); let diff_result = same_signs * sign_diff; - + // Combined: result = diff_signs + diff_result constraints.push(result - diff_signs - diff_result); - + constraints } @@ -473,17 +460,17 @@ impl CpuAir { borrow: M31, ) -> Vec { let mut constraints = Vec::new(); - + // Constraint 1: result must be binary (0 or 1) constraints.push(result * (result - M31::ONE)); - + // Constraint 2: borrow must be binary (0 or 1) constraints.push(borrow * (borrow - M31::ONE)); - + // Constraint 3: For unsigned, a < b iff borrow occurred in a - b // result = borrow constraints.push(result - borrow); - + constraints } @@ -511,21 +498,21 @@ impl CpuAir { borrow: M31, ) -> (M31, M31) { let two_16 = M31::new(1 << 16); - + // Low limb: result_lo + b_lo = a_lo + borrow * 2^16 // If a_lo < b_lo, we borrow from high (borrow = 1) let c_lo = a_lo + borrow * two_16 - b_lo - result_lo; - + // High limb: result_hi + b_hi + borrow = a_hi // We subtract the borrowed amount from high limb let c_hi = a_hi - b_hi - borrow - result_hi; - + (c_lo, c_hi) } /// Evaluate LB (Load Byte) constraint. /// rd = sign_extend(mem[addr][7:0]) - /// + /// /// # Arguments /// * `mem_value` - Full 32-bit word from memory /// * `byte_offset` - Which byte to load (0-3) @@ -533,16 +520,12 @@ impl CpuAir { /// /// # Returns /// Constraint ensuring correct byte extraction and sign extension - pub fn load_byte_constraint( - mem_value: M31, - byte_offset: M31, - rd_val: M31, - ) -> M31 { + pub fn load_byte_constraint(mem_value: M31, byte_offset: M31, rd_val: M31) -> M31 { // Extract byte based on offset (0-3) // For simplicity, assume byte_offset is validated elsewhere // byte = (mem_value >> (8 * byte_offset)) & 0xFF // sign_extend = byte < 128 ? byte : byte | 0xFFFFFF00 - + // This requires bit decomposition of the byte // Placeholder: will be implemented with proper bit extraction mem_value - rd_val - byte_offset + mem_value // Placeholder identity @@ -558,14 +541,10 @@ impl CpuAir { /// /// # Returns /// Constraint ensuring correct halfword extraction and sign extension - pub fn load_halfword_constraint( - mem_value: M31, - half_offset: M31, - rd_val: M31, - ) -> M31 { + pub fn load_halfword_constraint(mem_value: M31, half_offset: M31, rd_val: M31) -> M31 { // Extract halfword: (mem_value >> (16 * half_offset)) & 0xFFFF // sign_extend = half < 32768 ? half : half | 0xFFFF0000 - + // Placeholder: requires proper extraction logic mem_value - rd_val - half_offset + mem_value } @@ -580,10 +559,7 @@ impl CpuAir { /// # Returns /// Constraint: rd_val = mem_value #[inline] - pub fn load_word_constraint( - mem_value: M31, - rd_val: M31, - ) -> M31 { + pub fn load_word_constraint(mem_value: M31, rd_val: M31) -> M31 { rd_val - mem_value } @@ -597,14 +573,10 @@ impl CpuAir { /// /// # Returns /// Constraint ensuring correct byte extraction and zero extension - pub fn load_byte_unsigned_constraint( - mem_value: M31, - byte_offset: M31, - rd_val: M31, - ) -> M31 { + pub fn load_byte_unsigned_constraint(mem_value: M31, byte_offset: M31, rd_val: M31) -> M31 { // byte = (mem_value >> (8 * byte_offset)) & 0xFF // zero_extend = byte (no sign extension) - + // Placeholder mem_value - rd_val - byte_offset + mem_value } @@ -619,14 +591,10 @@ impl CpuAir { /// /// # Returns /// Constraint ensuring correct halfword extraction and zero extension - pub fn load_halfword_unsigned_constraint( - mem_value: M31, - half_offset: M31, - rd_val: M31, - ) -> M31 { + pub fn load_halfword_unsigned_constraint(mem_value: M31, half_offset: M31, rd_val: M31) -> M31 { // half = (mem_value >> (16 * half_offset)) & 0xFFFF // zero_extend = half - + // Placeholder mem_value - rd_val - half_offset + mem_value } @@ -650,7 +618,7 @@ impl CpuAir { ) -> M31 { // Mask out target byte, insert new byte // new = (old & ~(0xFF << (8*offset))) | ((byte & 0xFF) << (8*offset)) - + // Placeholder old_mem_value - new_mem_value - byte_to_store - byte_offset + old_mem_value } @@ -673,7 +641,7 @@ impl CpuAir { half_offset: M31, ) -> M31 { // new = (old & ~(0xFFFF << (16*offset))) | ((half & 0xFFFF) << (16*offset)) - + // Placeholder old_mem_value - new_mem_value - half_to_store - half_offset + old_mem_value } @@ -688,10 +656,7 @@ impl CpuAir { /// # Returns /// Constraint: new_mem_value = rs2_val #[inline] - pub fn store_word_constraint( - new_mem_value: M31, - rs2_val: M31, - ) -> M31 { + pub fn store_word_constraint(new_mem_value: M31, rs2_val: M31) -> M31 { new_mem_value - rs2_val } @@ -704,13 +669,10 @@ impl CpuAir { /// /// # Returns /// Constraint: is_word_access * (addr_lo % 4) = 0 - pub fn word_alignment_constraint( - addr_lo: M31, - is_word_access: M31, - ) -> M31 { + pub fn word_alignment_constraint(addr_lo: M31, is_word_access: M31) -> M31 { // addr_lo % 4 = addr_lo & 3 // Need bit decomposition to check low 2 bits are 0 - + // Simplified: Check addr_lo mod 4 via auxiliary witness // For now, placeholder assuming alignment is pre-checked is_word_access * (addr_lo - addr_lo) // Identity @@ -725,13 +687,10 @@ impl CpuAir { /// /// # Returns /// Constraint: is_half_access * (addr_lo % 2) = 0 - pub fn halfword_alignment_constraint( - addr_lo: M31, - is_half_access: M31, - ) -> M31 { + pub fn halfword_alignment_constraint(addr_lo: M31, is_half_access: M31) -> M31 { // addr_lo % 2 = addr_lo & 1 // Check low bit is 0 - + // Placeholder is_half_access * (addr_lo - addr_lo) } @@ -770,11 +729,11 @@ impl CpuAir { // Using limbs: rs1 = rs1_hi * 2^16 + rs1_lo, same for rs2 // Full product = (rs1_hi*2^16 + rs1_lo) * (rs2_hi*2^16 + rs2_lo) // = rs1_hi*rs2_hi*2^32 + rs1_hi*rs2_lo*2^16 + rs1_lo*rs2_hi*2^16 + rs1_lo*rs2_lo - + // We need auxiliary witnesses for intermediate products and carries // For now, simplified: check that reconstruction matches // Real implementation needs range checks on all limbs and carry propagation - + // Placeholder helper for tests; production constraints live in rv32im.rs rd_val_lo - (rs1_lo * rs2_lo) - product_hi_lo + product_hi_lo } @@ -803,7 +762,7 @@ impl CpuAir { ) -> M31 { // MULH returns upper 32 bits of signed 32x32->64 multiply // Needs sign extension logic and proper 64-bit computation - + // Placeholder helper for tests; production constraints live in rv32im.rs rd_val_lo - (rs1_hi * rs2_hi) - product_lo_lo + product_lo_lo } @@ -879,7 +838,7 @@ impl CpuAir { // Division constraint: dividend = divisor * quotient + remainder // rs1 = rs2 * quotient + remainder // Needs range check: |remainder| < |divisor| - + // Simplified reconstruction check (full implementation lives in rv32im.rs) // Placeholder: check basic reconstruction of low limb rs1_lo - (rs2_lo * quotient_lo + remainder_lo) @@ -907,7 +866,7 @@ impl CpuAir { ) -> M31 { // Unsigned division: simpler than signed // rs1 = rs2 * quotient + remainder, with remainder < rs2 - + // Placeholder rs1_lo - (rs2_lo * quotient_lo + remainder_lo) } @@ -935,7 +894,7 @@ impl CpuAir { ) -> M31 { // Same as DIV but remainder is the result // rs1 = rs2 * quotient + remainder - + // Placeholder rs1_lo - (rs2_lo * quotient_lo + remainder_lo) } @@ -962,7 +921,7 @@ impl CpuAir { ) -> M31 { // Unsigned remainder // rs1 = rs2 * quotient + remainder, with remainder < rs2 - + // Placeholder rs1_lo - (rs2_lo * quotient_lo + remainder_lo) } @@ -1000,23 +959,23 @@ impl CpuAir { // eq_result = 1 iff equal // branch_taken = eq_result // next_pc = branch_taken ? (pc + offset) : (pc + 4) - + // Constraint 1: branch_taken = eq_result let c1 = branch_taken - eq_result; - + // Constraint 2: eq_result is binary let c2 = eq_result * (M31::ONE - eq_result); - + // Constraint 3: If eq_result=1, then diff must be zero let diff_lo = rs1_lo - rs2_lo; let diff_hi = rs1_hi - rs2_hi; let c3 = eq_result * (diff_lo + diff_hi); - + // Constraint 4: PC update let four = M31::new(4); let expected_pc = branch_taken * (pc + offset) + (M31::ONE - branch_taken) * (pc + four); let c4 = next_pc - expected_pc; - + // Combine constraints (simplified - in practice would return array) c1 + c2 + c3 + c4 } @@ -1035,20 +994,20 @@ impl CpuAir { ) -> M31 { // branch_taken = 1 iff rs1 != rs2 // ne_result = 1 - eq_result - + let diff_lo = rs1_lo - rs2_lo; let diff_hi = rs1_hi - rs2_hi; - + // If ne_result=1, at least one diff must be non-zero // If ne_result=0, both diffs must be zero let c1 = (M31::ONE - ne_result) * (diff_lo + diff_hi); let c2 = ne_result * (M31::ONE - ne_result); // Binary let c3 = branch_taken - ne_result; - + let four = M31::new(4); let expected_pc = branch_taken * (pc + offset) + (M31::ONE - branch_taken) * (pc + four); let c4 = next_pc - expected_pc; - + c1 + c2 + c3 + c4 } @@ -1066,14 +1025,14 @@ impl CpuAir { ) -> M31 { // Reuse signed comparison logic // branch_taken = lt_result - + let c1 = branch_taken - lt_result; let c2 = lt_result * (M31::ONE - lt_result); // Binary - + let four = M31::new(4); let expected_pc = branch_taken * (pc + offset) + (M31::ONE - branch_taken) * (pc + four); let c3 = next_pc - expected_pc; - + // Placeholder helper for tests; production constraints live in rv32im.rs c1 + c2 + c3 } @@ -1093,11 +1052,11 @@ impl CpuAir { // ge_result = 1 - lt_result let c1 = branch_taken - ge_result; let c2 = ge_result * (M31::ONE - ge_result); - + let four = M31::new(4); let expected_pc = branch_taken * (pc + offset) + (M31::ONE - branch_taken) * (pc + four); let c3 = next_pc - expected_pc; - + c1 + c2 + c3 } @@ -1116,11 +1075,11 @@ impl CpuAir { // Use unsigned comparison (borrow detection) let c1 = branch_taken - ltu_result; let c2 = ltu_result * (M31::ONE - ltu_result); - + let four = M31::new(4); let expected_pc = branch_taken * (pc + offset) + (M31::ONE - branch_taken) * (pc + four); let c3 = next_pc - expected_pc; - + c1 + c2 + c3 } @@ -1140,11 +1099,11 @@ impl CpuAir { let geu_result = M31::ONE - lt_result; let c1 = branch_taken - geu_result; let c2 = geu_result * (M31::ONE - geu_result); - + let four = M31::new(4); let expected_pc = branch_taken * (pc + offset) + (M31::ONE - branch_taken) * (pc + four); let c3 = next_pc - expected_pc; - + c1 + c2 + c3 } @@ -1159,19 +1118,14 @@ impl CpuAir { /// /// # Returns /// Constraints ensuring correct JAL behavior - pub fn jal_constraint( - pc: M31, - next_pc: M31, - rd_val: M31, - offset: M31, - ) -> M31 { + pub fn jal_constraint(pc: M31, next_pc: M31, rd_val: M31, offset: M31) -> M31 { // Constraint 1: rd = pc + 4 let four = M31::new(4); let c1 = rd_val - (pc + four); - + // Constraint 2: next_pc = pc + offset let c2 = next_pc - (pc + offset); - + c1 + c2 } @@ -1187,22 +1141,16 @@ impl CpuAir { /// /// # Returns /// Constraints ensuring correct JALR behavior - pub fn jalr_constraint( - pc: M31, - rs1_val: M31, - next_pc: M31, - rd_val: M31, - offset: M31, - ) -> M31 { + pub fn jalr_constraint(pc: M31, rs1_val: M31, next_pc: M31, rd_val: M31, offset: M31) -> M31 { // Constraint 1: rd = pc + 4 let four = M31::new(4); let c1 = rd_val - (pc + four); - + // Constraint 2: next_pc = (rs1 + offset) & ~1 // The LSB masking ensures PC is always aligned // Simplified helper: assume next_pc = rs1 + offset (alignment checked in rv32im.rs) let c2 = next_pc - (rs1_val + offset); - + // Placeholder helper for tests; production constraints implemented in rv32im.rs c1 + c2 } @@ -1240,7 +1188,7 @@ mod tests { let bits = u32_to_bits(value); let constraints = CpuAir::bit_decomposition_constraints(lo, hi, &bits); - + // All 34 constraints should be satisfied (= 0) assert_eq!(constraints.len(), 34); for (i, constraint) in constraints.iter().enumerate() { @@ -1255,7 +1203,7 @@ mod tests { let bits = u32_to_bits(value); let constraints = CpuAir::bit_decomposition_constraints(lo, hi, &bits); - + for constraint in constraints { assert_eq!(constraint, M31::ZERO); } @@ -1268,7 +1216,7 @@ mod tests { let bits = u32_to_bits(value); let constraints = CpuAir::bit_decomposition_constraints(lo, hi, &bits); - + for constraint in constraints { assert_eq!(constraint, M31::ZERO); } @@ -1286,7 +1234,7 @@ mod tests { let bits_result = u32_to_bits(result); let constraints = CpuAir::bitwise_and_constraints(&bits_a, &bits_b, &bits_result); - + assert_eq!(constraints.len(), 32); for constraint in constraints { assert_eq!(constraint, M31::ZERO); @@ -1309,10 +1257,16 @@ mod tests { let bits_result = u32_to_bits(expected); let constraints = CpuAir::bitwise_and_constraints(&bits_a, &bits_b, &bits_result); - + for (i, constraint) in constraints.iter().enumerate() { - assert_eq!(*constraint, M31::ZERO, - "AND failed for case ({:#x}, {:#x}), bit {}", a, b, i); + assert_eq!( + *constraint, + M31::ZERO, + "AND failed for case ({:#x}, {:#x}), bit {}", + a, + b, + i + ); } } } @@ -1329,7 +1283,7 @@ mod tests { let bits_result = u32_to_bits(result); let constraints = CpuAir::bitwise_or_constraints(&bits_a, &bits_b, &bits_result); - + assert_eq!(constraints.len(), 32); for constraint in constraints { assert_eq!(constraint, M31::ZERO); @@ -1351,10 +1305,16 @@ mod tests { let bits_result = u32_to_bits(expected); let constraints = CpuAir::bitwise_or_constraints(&bits_a, &bits_b, &bits_result); - + for (i, constraint) in constraints.iter().enumerate() { - assert_eq!(*constraint, M31::ZERO, - "OR failed for case ({:#x}, {:#x}), bit {}", a, b, i); + assert_eq!( + *constraint, + M31::ZERO, + "OR failed for case ({:#x}, {:#x}), bit {}", + a, + b, + i + ); } } } @@ -1371,7 +1331,7 @@ mod tests { let bits_result = u32_to_bits(result); let constraints = CpuAir::bitwise_xor_constraints(&bits_a, &bits_b, &bits_result); - + assert_eq!(constraints.len(), 32); for constraint in constraints { assert_eq!(constraint, M31::ZERO); @@ -1393,10 +1353,16 @@ mod tests { let bits_result = u32_to_bits(expected); let constraints = CpuAir::bitwise_xor_constraints(&bits_a, &bits_b, &bits_result); - + for (i, constraint) in constraints.iter().enumerate() { - assert_eq!(*constraint, M31::ZERO, - "XOR failed for case ({:#x}, {:#x}), bit {}", a, b, i); + assert_eq!( + *constraint, + M31::ZERO, + "XOR failed for case ({:#x}, {:#x}), bit {}", + a, + b, + i + ); } } } @@ -1413,7 +1379,7 @@ mod tests { let bits_wrong = u32_to_bits(wrong_result); let constraints = CpuAir::bitwise_and_constraints(&bits_a, &bits_b, &bits_wrong); - + // Should have non-zero constraints let has_nonzero = constraints.iter().any(|c| *c != M31::ZERO); assert!(has_nonzero, "Constraint should catch incorrect AND result"); @@ -1425,15 +1391,22 @@ mod tests { let value = 0x12345678u32; let (lo, hi) = u32_to_limbs(value); let mut bits = u32_to_bits(value); - + // Flip a bit - bits[5] = if bits[5] == M31::ZERO { M31::ONE } else { M31::ZERO }; + bits[5] = if bits[5] == M31::ZERO { + M31::ONE + } else { + M31::ZERO + }; let constraints = CpuAir::bit_decomposition_constraints(lo, hi, &bits); - + // Should have non-zero constraints (reconstruction will fail) let has_nonzero = constraints.iter().any(|c| *c != M31::ZERO); - assert!(has_nonzero, "Constraint should catch incorrect bit decomposition"); + assert!( + has_nonzero, + "Constraint should catch incorrect bit decomposition" + ); } #[test] @@ -1447,11 +1420,8 @@ mod tests { let bits_result = u32_to_bits(expected); let shift_m31 = M31::new(shift); - let constraints = CpuAir::shift_left_logical_constraints( - &bits_value, - &bits_result, - shift_m31, - ); + let constraints = + CpuAir::shift_left_logical_constraints(&bits_value, &bits_result, shift_m31); assert_eq!(constraints.len(), 32); for (i, constraint) in constraints.iter().enumerate() { @@ -1475,16 +1445,17 @@ mod tests { let bits_result = u32_to_bits(expected); let shift_m31 = M31::new(shift); - let constraints = CpuAir::shift_left_logical_constraints( - &bits_value, - &bits_result, - shift_m31, - ); + let constraints = + CpuAir::shift_left_logical_constraints(&bits_value, &bits_result, shift_m31); for (i, constraint) in constraints.iter().enumerate() { assert_eq!( - *constraint, M31::ZERO, - "SLL({:#x} << {}) failed at bit {}", value, shift, i + *constraint, + M31::ZERO, + "SLL({:#x} << {}) failed at bit {}", + value, + shift, + i ); } } @@ -1501,11 +1472,8 @@ mod tests { let bits_result = u32_to_bits(expected); let shift_m31 = M31::new(shift); - let constraints = CpuAir::shift_right_logical_constraints( - &bits_value, - &bits_result, - shift_m31, - ); + let constraints = + CpuAir::shift_right_logical_constraints(&bits_value, &bits_result, shift_m31); assert_eq!(constraints.len(), 32); for constraint in constraints { @@ -1529,16 +1497,17 @@ mod tests { let bits_result = u32_to_bits(expected); let shift_m31 = M31::new(shift); - let constraints = CpuAir::shift_right_logical_constraints( - &bits_value, - &bits_result, - shift_m31, - ); + let constraints = + CpuAir::shift_right_logical_constraints(&bits_value, &bits_result, shift_m31); for (i, constraint) in constraints.iter().enumerate() { assert_eq!( - *constraint, M31::ZERO, - "SRL({:#x} >> {}) failed at bit {}", value, shift, i + *constraint, + M31::ZERO, + "SRL({:#x} >> {}) failed at bit {}", + value, + shift, + i ); } } @@ -1555,11 +1524,8 @@ mod tests { let bits_result = u32_to_bits(expected); let shift_m31 = M31::new(shift); - let constraints = CpuAir::shift_right_arithmetic_constraints( - &bits_value, - &bits_result, - shift_m31, - ); + let constraints = + CpuAir::shift_right_arithmetic_constraints(&bits_value, &bits_result, shift_m31); assert_eq!(constraints.len(), 32); for constraint in constraints { @@ -1578,11 +1544,8 @@ mod tests { let bits_result = u32_to_bits(expected); let shift_m31 = M31::new(shift); - let constraints = CpuAir::shift_right_arithmetic_constraints( - &bits_value, - &bits_result, - shift_m31, - ); + let constraints = + CpuAir::shift_right_arithmetic_constraints(&bits_value, &bits_result, shift_m31); for constraint in constraints { assert_eq!(constraint, M31::ZERO, "SRA sign extension failed"); @@ -1593,12 +1556,12 @@ mod tests { fn test_shift_right_arithmetic_comprehensive() { let test_cases = [ // (value, shift, expected_sra) - (0x00000008, 1, 0x00000004), // Positive: 8 >> 1 = 4 - (0x00000008, 2, 0x00000002), // Positive: 8 >> 2 = 2 - (0xFFFFFFF8u32, 1, 0xFFFFFFFCu32), // Negative: -8 >> 1 = -4 (sign extend) - (0xFFFFFFF8u32, 2, 0xFFFFFFFEu32), // Negative: -8 >> 2 = -2 (sign extend) + (0x00000008, 1, 0x00000004), // Positive: 8 >> 1 = 4 + (0x00000008, 2, 0x00000002), // Positive: 8 >> 2 = 2 + (0xFFFFFFF8u32, 1, 0xFFFFFFFCu32), // Negative: -8 >> 1 = -4 (sign extend) + (0xFFFFFFF8u32, 2, 0xFFFFFFFEu32), // Negative: -8 >> 2 = -2 (sign extend) (0x80000000u32, 31, 0xFFFFFFFFu32), // Min int >> 31 = -1 (all ones) - (0x7FFFFFFF, 31, 0x00000000), // Max int >> 31 = 0 + (0x7FFFFFFF, 31, 0x00000000), // Max int >> 31 = 0 ]; for (value, shift, expected) in test_cases { @@ -1606,17 +1569,18 @@ mod tests { let bits_result = u32_to_bits(expected); let shift_m31 = M31::new(shift); - let constraints = CpuAir::shift_right_arithmetic_constraints( - &bits_value, - &bits_result, - shift_m31, - ); + let constraints = + CpuAir::shift_right_arithmetic_constraints(&bits_value, &bits_result, shift_m31); for (i, constraint) in constraints.iter().enumerate() { assert_eq!( - *constraint, M31::ZERO, + *constraint, + M31::ZERO, "SRA({:#x} >> {}) failed at bit {}, expected {:#x}", - value, shift, i, expected + value, + shift, + i, + expected ); } } @@ -1633,24 +1597,24 @@ mod tests { let bits_wrong = u32_to_bits(wrong_result); let shift_m31 = M31::new(shift); - let constraints = CpuAir::shift_left_logical_constraints( - &bits_value, - &bits_wrong, - shift_m31, - ); + let constraints = + CpuAir::shift_left_logical_constraints(&bits_value, &bits_wrong, shift_m31); let has_nonzero = constraints.iter().any(|c| *c != M31::ZERO); - assert!(has_nonzero, "Constraint should catch incorrect shift result"); + assert!( + has_nonzero, + "Constraint should catch incorrect shift result" + ); } #[test] fn test_set_less_than_unsigned() { // Test SLTU: unsigned comparison let test_cases = [ - (5u32, 10u32, 1u32, 1u32), // 5 < 10 = true, borrow = 1 - (10u32, 5u32, 0u32, 0u32), // 10 < 5 = false, borrow = 0 - (5u32, 5u32, 0u32, 0u32), // 5 < 5 = false, borrow = 0 - (0u32, 1u32, 1u32, 1u32), // 0 < 1 = true, borrow = 1 + (5u32, 10u32, 1u32, 1u32), // 5 < 10 = true, borrow = 1 + (10u32, 5u32, 0u32, 0u32), // 10 < 5 = false, borrow = 0 + (5u32, 5u32, 0u32, 0u32), // 5 < 5 = false, borrow = 0 + (0u32, 1u32, 1u32, 1u32), // 0 < 1 = true, borrow = 1 (0xFFFFFFFFu32, 0u32, 0u32, 0u32), // max < 0 = false (unsigned) ]; @@ -1660,17 +1624,17 @@ mod tests { let result = M31::new(expected_result); let borrow = M31::new(expected_borrow); - let constraints = CpuAir::set_less_than_unsigned_constraints( - &bits_a, - &bits_b, - result, - borrow, - ); + let constraints = + CpuAir::set_less_than_unsigned_constraints(&bits_a, &bits_b, result, borrow); for (i, constraint) in constraints.iter().enumerate() { assert_eq!( - *constraint, M31::ZERO, - "SLTU({} < {}) failed at constraint {}", a, b, i + *constraint, + M31::ZERO, + "SLTU({} < {}) failed at constraint {}", + a, + b, + i ); } } @@ -1680,23 +1644,19 @@ mod tests { fn test_set_less_than_signed_same_sign() { // Test SLT with same sign (both positive or both negative) // When signs are same, compare magnitudes via subtraction - + // Case 1: Both positive let a = 5u32; let b = 10u32; let diff = (a.wrapping_sub(b)) as u32; // Will have sign bit set - + let bits_a = u32_to_bits(a); let bits_b = u32_to_bits(b); let diff_bits = u32_to_bits(diff); let result = M31::ONE; // 5 < 10 = true - let constraints = CpuAir::set_less_than_signed_constraints( - &bits_a, - &bits_b, - result, - &diff_bits, - ); + let constraints = + CpuAir::set_less_than_signed_constraints(&bits_a, &bits_b, result, &diff_bits); for constraint in constraints { assert_eq!(constraint, M31::ZERO, "SLT(5 < 10) failed"); @@ -1708,23 +1668,19 @@ mod tests { // Test SLT with different signs // Negative < Positive = true // Positive < Negative = false - + // Case 1: negative < positive (true) let a = 0xFFFFFFFEu32; // -2 in two's complement - let b = 5u32; // +5 + let b = 5u32; // +5 let diff = a.wrapping_sub(b); - + let bits_a = u32_to_bits(a); let bits_b = u32_to_bits(b); let diff_bits = u32_to_bits(diff); let result = M31::ONE; // -2 < 5 = true - let constraints = CpuAir::set_less_than_signed_constraints( - &bits_a, - &bits_b, - result, - &diff_bits, - ); + let constraints = + CpuAir::set_less_than_signed_constraints(&bits_a, &bits_b, result, &diff_bits); for constraint in constraints { assert_eq!(constraint, M31::ZERO, "SLT(-2 < 5) failed"); @@ -1734,18 +1690,14 @@ mod tests { let a2 = 5u32; let b2 = 0xFFFFFFFEu32; // -2 let diff2 = a2.wrapping_sub(b2); - + let bits_a2 = u32_to_bits(a2); let bits_b2 = u32_to_bits(b2); let diff_bits2 = u32_to_bits(diff2); let result2 = M31::ZERO; // 5 < -2 = false - let constraints2 = CpuAir::set_less_than_signed_constraints( - &bits_a2, - &bits_b2, - result2, - &diff_bits2, - ); + let constraints2 = + CpuAir::set_less_than_signed_constraints(&bits_a2, &bits_b2, result2, &diff_bits2); for constraint in constraints2 { assert_eq!(constraint, M31::ZERO, "SLT(5 < -2) failed"); @@ -1756,7 +1708,7 @@ mod tests { fn test_sub_with_borrow() { // Test SUB constraint with borrow // Borrow occurs when low limb underflows: a_lo < b_lo - + // Case 1: 10 - 5 = 5, no borrow in limbs let a = 10u32; let b = 5u32; @@ -1764,7 +1716,11 @@ mod tests { let (b_lo, b_hi) = u32_to_limbs(b); let result = a.wrapping_sub(b); let (result_lo, result_hi) = u32_to_limbs(result); - let borrow = if a_lo.value() < b_lo.value() { M31::ONE } else { M31::ZERO }; + let borrow = if a_lo.value() < b_lo.value() { + M31::ONE + } else { + M31::ZERO + }; let (c_lo, c_hi) = CpuAir::sub_with_borrow_constraint( a_lo, a_hi, b_lo, b_hi, result_lo, result_hi, borrow, @@ -1779,13 +1735,29 @@ mod tests { let (b_lo2, b_hi2) = u32_to_limbs(b2); let result2 = a2.wrapping_sub(b2); let (result_lo2, result_hi2) = u32_to_limbs(result2); - let borrow2 = if a_lo2.value() < b_lo2.value() { M31::ONE } else { M31::ZERO }; + let borrow2 = if a_lo2.value() < b_lo2.value() { + M31::ONE + } else { + M31::ZERO + }; let (c_lo2, c_hi2) = CpuAir::sub_with_borrow_constraint( a_lo2, a_hi2, b_lo2, b_hi2, result_lo2, result_hi2, borrow2, ); - assert_eq!(c_lo2, M31::ZERO, "SUB({:#x} - {:#x}) low limb failed", a2, b2); - assert_eq!(c_hi2, M31::ZERO, "SUB({:#x} - {:#x}) high limb failed", a2, b2); + assert_eq!( + c_lo2, + M31::ZERO, + "SUB({:#x} - {:#x}) low limb failed", + a2, + b2 + ); + assert_eq!( + c_hi2, + M31::ZERO, + "SUB({:#x} - {:#x}) high limb failed", + a2, + b2 + ); // Case 3: 0x20000 - 0x10005 = 0xFFFB, requires borrow let a3 = 0x20000u32; @@ -1794,13 +1766,29 @@ mod tests { let (b_lo3, b_hi3) = u32_to_limbs(b3); let result3 = a3.wrapping_sub(b3); let (result_lo3, result_hi3) = u32_to_limbs(result3); - let borrow3 = if a_lo3.value() < b_lo3.value() { M31::ONE } else { M31::ZERO }; + let borrow3 = if a_lo3.value() < b_lo3.value() { + M31::ONE + } else { + M31::ZERO + }; let (c_lo3, c_hi3) = CpuAir::sub_with_borrow_constraint( a_lo3, a_hi3, b_lo3, b_hi3, result_lo3, result_hi3, borrow3, ); - assert_eq!(c_lo3, M31::ZERO, "SUB({:#x} - {:#x}) low limb failed", a3, b3); - assert_eq!(c_hi3, M31::ZERO, "SUB({:#x} - {:#x}) high limb failed", a3, b3); + assert_eq!( + c_lo3, + M31::ZERO, + "SUB({:#x} - {:#x}) low limb failed", + a3, + b3 + ); + assert_eq!( + c_hi3, + M31::ZERO, + "SUB({:#x} - {:#x}) high limb failed", + a3, + b3 + ); } #[test] @@ -1814,15 +1802,14 @@ mod tests { let bits_a = u32_to_bits(a); let bits_b = u32_to_bits(b); - let constraints = CpuAir::set_less_than_unsigned_constraints( - &bits_a, - &bits_b, - wrong_result, - borrow, - ); + let constraints = + CpuAir::set_less_than_unsigned_constraints(&bits_a, &bits_b, wrong_result, borrow); let has_nonzero = constraints.iter().any(|c| *c != M31::ZERO); - assert!(has_nonzero, "Constraint should catch incorrect comparison result"); + assert!( + has_nonzero, + "Constraint should catch incorrect comparison result" + ); } #[test] @@ -1835,14 +1822,13 @@ mod tests { let (rs1_lo, rs1_hi) = u32_to_limbs(rs1); let (imm_lo, imm_hi) = u32_to_limbs(imm); let (result_lo, result_hi) = u32_to_limbs(expected); - + // No carry for this case let carry = M31::ZERO; let is_addi = M31::ONE; let (c_lo, c_hi) = CpuAir::addi_constraint( - is_addi, result_lo, result_hi, - rs1_lo, rs1_hi, imm_lo, imm_hi, carry + is_addi, result_lo, result_hi, rs1_lo, rs1_hi, imm_lo, imm_hi, carry, ); assert_eq!(c_lo, M31::ZERO, "ADDI low limb failed"); @@ -1909,7 +1895,7 @@ mod tests { let rs1 = 0xFFFFFFFEu32; // -2 let imm = 5u32; let diff = rs1.wrapping_sub(imm); - + let bits_rs1 = u32_to_bits(rs1); let bits_imm = u32_to_bits(imm); let diff_bits = u32_to_bits(diff); @@ -1927,7 +1913,7 @@ mod tests { // Test SLTIU: unsigned comparison with immediate let rs1 = 5u32; let imm = 10u32; - + let bits_rs1 = u32_to_bits(rs1); let bits_imm = u32_to_bits(imm); let result = M31::ONE; // 5 < 10 = true (unsigned) @@ -2012,12 +1998,18 @@ mod tests { let (imm_lo, imm_hi) = u32_to_limbs(imm); let (result_lo, result_hi) = u32_to_limbs(expected); let carry = M31::ZERO; - + let (c_lo, c_hi) = CpuAir::addi_constraint( - M31::ONE, result_lo, result_hi, - rs1_lo, rs1_hi, imm_lo, imm_hi, carry + M31::ONE, + result_lo, + result_hi, + rs1_lo, + rs1_hi, + imm_lo, + imm_hi, + carry, ); - + assert_eq!(c_lo, M31::ZERO, "ADDI({} + {}) failed", rs1, imm); assert_eq!(c_hi, M31::ZERO, "ADDI({} + {}) failed", rs1, imm); } @@ -2025,7 +2017,7 @@ mod tests { let bits_rs1 = u32_to_bits(rs1); let bits_imm = u32_to_bits(imm); let bits_result = u32_to_bits(expected); - + let constraints = CpuAir::andi_constraint(&bits_rs1, &bits_imm, &bits_result); for c in constraints { assert_eq!(c, M31::ZERO, "ANDI({:#x} & {:#x}) failed", rs1, imm); @@ -2035,7 +2027,7 @@ mod tests { let bits_rs1 = u32_to_bits(rs1); let bits_imm = u32_to_bits(imm); let bits_result = u32_to_bits(expected); - + let constraints = CpuAir::ori_constraint(&bits_rs1, &bits_imm, &bits_result); for c in constraints { assert_eq!(c, M31::ZERO, "ORI({:#x} | {:#x}) failed", rs1, imm); @@ -2045,7 +2037,7 @@ mod tests { let bits_rs1 = u32_to_bits(rs1); let bits_imm = u32_to_bits(imm); let bits_result = u32_to_bits(expected); - + let constraints = CpuAir::xori_constraint(&bits_rs1, &bits_imm, &bits_result); for c in constraints { assert_eq!(c, M31::ZERO, "XORI({:#x} ^ {:#x}) failed", rs1, imm); @@ -2068,7 +2060,11 @@ mod tests { // Test with wrong value let wrong_rd = M31::new(0x11111111); let wrong_constraint = CpuAir::load_word_constraint(mem_value, wrong_rd); - assert_ne!(wrong_constraint, M31::ZERO, "LW should catch incorrect value"); + assert_ne!( + wrong_constraint, + M31::ZERO, + "LW should catch incorrect value" + ); } #[test] @@ -2083,7 +2079,11 @@ mod tests { // Test with wrong stored value let wrong_mem = M31::new(0xABCDEF01); let wrong_constraint = CpuAir::store_word_constraint(wrong_mem, rs2_val); - assert_ne!(wrong_constraint, M31::ZERO, "SW should catch incorrect stored value"); + assert_ne!( + wrong_constraint, + M31::ZERO, + "SW should catch incorrect stored value" + ); } #[test] @@ -2138,7 +2138,11 @@ mod tests { let is_half = M31::ONE; let constraint = CpuAir::halfword_alignment_constraint(aligned_addr, is_half); - assert_eq!(constraint, M31::ZERO, "Halfword alignment constraint failed"); + assert_eq!( + constraint, + M31::ZERO, + "Halfword alignment constraint failed" + ); } // ============================================================================ @@ -2160,14 +2164,7 @@ mod tests { let (prod_hi_lo, prod_hi_hi) = u32_to_limbs(product_hi); let constraint = CpuAir::mul_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - rd_lo, - rd_hi, - prod_hi_lo, - prod_hi_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, rd_lo, rd_hi, prod_hi_lo, prod_hi_hi, ); // Placeholder implementation - just verify it compiles @@ -2189,14 +2186,7 @@ mod tests { let (prod_hi_lo, prod_hi_hi) = u32_to_limbs(product_hi); let constraint = CpuAir::mul_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - rd_lo, - rd_hi, - prod_hi_lo, - prod_hi_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, rd_lo, rd_hi, prod_hi_lo, prod_hi_hi, ); assert_eq!(constraint, M31::ZERO, "MUL large numbers"); @@ -2217,14 +2207,7 @@ mod tests { let (prod_lo_lo, prod_lo_hi) = u32_to_limbs(product_lo); let constraint = CpuAir::mulh_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - rd_lo, - rd_hi, - prod_lo_lo, - prod_lo_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, rd_lo, rd_hi, prod_lo_lo, prod_lo_hi, ); // Placeholder - verify compilation @@ -2245,14 +2228,7 @@ mod tests { let (prod_lo_lo, prod_lo_hi) = u32_to_limbs((product & 0xFFFFFFFF) as u32); let constraint = CpuAir::mulhu_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - rd_lo, - rd_hi, - prod_lo_lo, - prod_lo_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, rd_lo, rd_hi, prod_lo_lo, prod_lo_hi, ); let _ = constraint; @@ -2272,14 +2248,7 @@ mod tests { let (rem_lo, rem_hi) = u32_to_limbs(remainder); let constraint = CpuAir::div_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - quot_lo, - quot_hi, - rem_lo, - rem_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, quot_lo, quot_hi, rem_lo, rem_hi, ); assert_eq!(constraint, M31::ZERO, "DIV basic constraint"); @@ -2299,14 +2268,7 @@ mod tests { let (rem_lo, rem_hi) = u32_to_limbs(remainder); let constraint = CpuAir::div_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - quot_lo, - quot_hi, - rem_lo, - rem_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, quot_lo, quot_hi, rem_lo, rem_hi, ); // Placeholder - simplified limb check doesn't handle carries properly @@ -2328,14 +2290,7 @@ mod tests { let (rem_lo, rem_hi) = u32_to_limbs(remainder); let constraint = CpuAir::divu_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - quot_lo, - quot_hi, - rem_lo, - rem_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, quot_lo, quot_hi, rem_lo, rem_hi, ); // Placeholder - simplified limb check doesn't handle carries @@ -2356,14 +2311,7 @@ mod tests { let (rem_lo, rem_hi) = u32_to_limbs(remainder); let constraint = CpuAir::rem_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - quot_lo, - quot_hi, - rem_lo, - rem_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, quot_lo, quot_hi, rem_lo, rem_hi, ); assert_eq!(constraint, M31::ZERO, "REM basic constraint"); @@ -2383,14 +2331,7 @@ mod tests { let (rem_lo, rem_hi) = u32_to_limbs(remainder); let constraint = CpuAir::remu_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - quot_lo, - quot_hi, - rem_lo, - rem_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, quot_lo, quot_hi, rem_lo, rem_hi, ); // Placeholder - simplified limb check doesn't handle carries @@ -2412,14 +2353,7 @@ mod tests { let (prod_hi_lo, prod_hi_hi) = u32_to_limbs((correct_product >> 32) as u32); let constraint = CpuAir::mul_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - rd_lo, - rd_hi, - prod_hi_lo, - prod_hi_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, rd_lo, rd_hi, prod_hi_lo, prod_hi_hi, ); // Placeholder won't catch this yet, but verify it compiles @@ -2440,14 +2374,7 @@ mod tests { let (rem_lo, rem_hi) = u32_to_limbs(remainder); let constraint = CpuAir::div_constraint( - rs1_lo, - rs1_hi, - rs2_lo, - rs2_hi, - quot_lo, - quot_hi, - rem_lo, - rem_hi, + rs1_lo, rs1_hi, rs2_lo, rs2_hi, quot_lo, quot_hi, rem_lo, rem_hi, ); // Should detect incorrect quotient (when fully implemented) @@ -2466,18 +2393,25 @@ mod tests { let rs2 = 0x12345678u32; let (rs1_lo, rs1_hi) = u32_to_limbs(rs1); let (rs2_lo, rs2_hi) = u32_to_limbs(rs2); - + let eq_result = M31::ONE; // Equal let branch_taken = M31::ONE; // Branch taken let pc = M31::new(0x1000); let offset = M31::new(0x100); // Branch offset let next_pc = M31::new(0x1100); // pc + offset - + let constraint = CpuAir::beq_constraint( - rs1_lo, rs1_hi, rs2_lo, rs2_hi, - eq_result, branch_taken, pc, next_pc, offset, + rs1_lo, + rs1_hi, + rs2_lo, + rs2_hi, + eq_result, + branch_taken, + pc, + next_pc, + offset, ); - + assert_eq!(constraint, M31::ZERO, "BEQ taken constraint failed"); } @@ -2488,18 +2422,25 @@ mod tests { let rs2 = 0x12345679u32; // Different let (rs1_lo, rs1_hi) = u32_to_limbs(rs1); let (rs2_lo, rs2_hi) = u32_to_limbs(rs2); - + let eq_result = M31::ZERO; // Not equal let branch_taken = M31::ZERO; // Branch not taken let pc = M31::new(0x1000); let offset = M31::new(0x100); let next_pc = M31::new(0x1004); // pc + 4 - + let constraint = CpuAir::beq_constraint( - rs1_lo, rs1_hi, rs2_lo, rs2_hi, - eq_result, branch_taken, pc, next_pc, offset, + rs1_lo, + rs1_hi, + rs2_lo, + rs2_hi, + eq_result, + branch_taken, + pc, + next_pc, + offset, ); - + assert_eq!(constraint, M31::ZERO, "BEQ not taken constraint failed"); } @@ -2510,18 +2451,25 @@ mod tests { let rs2 = 0x1234u32; let (rs1_lo, rs1_hi) = u32_to_limbs(rs1); let (rs2_lo, rs2_hi) = u32_to_limbs(rs2); - + let ne_result = M31::ONE; // Not equal let branch_taken = M31::ONE; let pc = M31::new(0x2000); let offset = M31::new(0x50); let next_pc = M31::new(0x2050); // pc + offset - + let constraint = CpuAir::bne_constraint( - rs1_lo, rs1_hi, rs2_lo, rs2_hi, - ne_result, branch_taken, pc, next_pc, offset, + rs1_lo, + rs1_hi, + rs2_lo, + rs2_hi, + ne_result, + branch_taken, + pc, + next_pc, + offset, ); - + assert_eq!(constraint, M31::ZERO, "BNE taken constraint failed"); } @@ -2532,18 +2480,25 @@ mod tests { let rs2 = 50u32; let (rs1_lo, rs1_hi) = u32_to_limbs(rs1); let (rs2_lo, rs2_hi) = u32_to_limbs(rs2); - + let lt_result = M31::ONE; // rs1 < rs2 let branch_taken = M31::ONE; let pc = M31::new(0x3000); let offset = M31::new(0x200); let next_pc = M31::new(0x3200); - + let constraint = CpuAir::blt_constraint( - rs1_lo, rs1_hi, rs2_lo, rs2_hi, - lt_result, branch_taken, pc, next_pc, offset, + rs1_lo, + rs1_hi, + rs2_lo, + rs2_hi, + lt_result, + branch_taken, + pc, + next_pc, + offset, ); - + assert_eq!(constraint, M31::ZERO, "BLT taken constraint failed"); } @@ -2554,18 +2509,25 @@ mod tests { let rs2 = 20u32; let (rs1_lo, rs1_hi) = u32_to_limbs(rs1); let (rs2_lo, rs2_hi) = u32_to_limbs(rs2); - + let ge_result = M31::ZERO; // rs1 < rs2, so NOT >= let branch_taken = M31::ZERO; let pc = M31::new(0x4000); let offset = M31::new(0x80); let next_pc = M31::new(0x4004); // pc + 4 - + let constraint = CpuAir::bge_constraint( - rs1_lo, rs1_hi, rs2_lo, rs2_hi, - ge_result, branch_taken, pc, next_pc, offset, + rs1_lo, + rs1_hi, + rs2_lo, + rs2_hi, + ge_result, + branch_taken, + pc, + next_pc, + offset, ); - + assert_eq!(constraint, M31::ZERO, "BGE not taken constraint failed"); } @@ -2576,18 +2538,25 @@ mod tests { let rs2 = 100u32; let (rs1_lo, rs1_hi) = u32_to_limbs(rs1); let (rs2_lo, rs2_hi) = u32_to_limbs(rs2); - + let ltu_result = M31::ONE; // 5 < 100 (unsigned) let branch_taken = M31::ONE; let pc = M31::new(0x5000); let offset = M31::new(0x40); let next_pc = M31::new(0x5040); - + let constraint = CpuAir::bltu_constraint( - rs1_lo, rs1_hi, rs2_lo, rs2_hi, - ltu_result, branch_taken, pc, next_pc, offset, + rs1_lo, + rs1_hi, + rs2_lo, + rs2_hi, + ltu_result, + branch_taken, + pc, + next_pc, + offset, ); - + assert_eq!(constraint, M31::ZERO, "BLTU taken constraint failed"); } @@ -2598,18 +2567,25 @@ mod tests { let rs2 = 0xFFFFu32; let (rs1_lo, rs1_hi) = u32_to_limbs(rs1); let (rs2_lo, rs2_hi) = u32_to_limbs(rs2); - + let lt_result = M31::ZERO; // Equal, so not less-than (geu = 1 - 0 = 1) let branch_taken = M31::ONE; let pc = M31::new(0x6000); let offset = M31::new(0x10); let next_pc = M31::new(0x6010); - + let constraint = CpuAir::bgeu_constraint( - rs1_lo, rs1_hi, rs2_lo, rs2_hi, - lt_result, branch_taken, pc, next_pc, offset, + rs1_lo, + rs1_hi, + rs2_lo, + rs2_hi, + lt_result, + branch_taken, + pc, + next_pc, + offset, ); - + assert_eq!(constraint, M31::ZERO, "BGEU taken constraint failed"); } @@ -2620,9 +2596,9 @@ mod tests { let offset = M31::new(0x200); let next_pc = M31::new(0x1200); // pc + offset let rd_val = M31::new(0x1004); // pc + 4 - + let constraint = CpuAir::jal_constraint(pc, next_pc, rd_val, offset); - + assert_eq!(constraint, M31::ZERO, "JAL constraint failed"); } @@ -2633,9 +2609,9 @@ mod tests { let offset = M31::new(0x200); let next_pc = M31::new(0x1200); let wrong_rd = M31::new(0x1008); // Wrong link value - + let constraint = CpuAir::jal_constraint(pc, next_pc, wrong_rd, offset); - + assert_ne!(constraint, M31::ZERO, "JAL should catch incorrect link"); } @@ -2647,9 +2623,9 @@ mod tests { let offset = M31::new(0x100); let next_pc = M31::new(0x5100); // rs1 + offset let rd_val = M31::new(0x2004); // pc + 4 - + let constraint = CpuAir::jalr_constraint(pc, rs1_val, next_pc, rd_val, offset); - + assert_eq!(constraint, M31::ZERO, "JALR constraint failed"); } @@ -2661,9 +2637,9 @@ mod tests { let offset = M31::new(0x100); let wrong_next_pc = M31::new(0x5200); // Incorrect target let rd_val = M31::new(0x2004); - + let constraint = CpuAir::jalr_constraint(pc, rs1_val, wrong_next_pc, rd_val, offset); - + assert_ne!(constraint, M31::ZERO, "JALR should catch incorrect target"); } @@ -2674,19 +2650,26 @@ mod tests { let rs2 = 200u32; let (rs1_lo, rs1_hi) = u32_to_limbs(rs1); let (rs2_lo, rs2_hi) = u32_to_limbs(rs2); - + // BEQ with rs1 != rs2 but claiming equality let wrong_eq = M31::ONE; // Wrong: they're not equal let branch_taken = M31::ONE; let pc = M31::new(0x1000); let offset = M31::new(0x100); let next_pc = M31::new(0x1100); - + let constraint = CpuAir::beq_constraint( - rs1_lo, rs1_hi, rs2_lo, rs2_hi, - wrong_eq, branch_taken, pc, next_pc, offset, + rs1_lo, + rs1_hi, + rs2_lo, + rs2_hi, + wrong_eq, + branch_taken, + pc, + next_pc, + offset, ); - + // Should fail because eq_result doesn't match actual equality assert_ne!(constraint, M31::ZERO, "Should detect incorrect eq_result"); } diff --git a/crates/air/src/lib.rs b/crates/air/src/lib.rs index 31ce791..c98dc43 100644 --- a/crates/air/src/lib.rs +++ b/crates/air/src/lib.rs @@ -2,10 +2,10 @@ //! //! All constraints are kept at degree ≤ 2 for efficient STARK proving. +pub mod constraints; pub mod cpu; pub mod memory; -pub mod constraints; pub mod rv32im; pub use constraints::{AirConstraint, ConstraintSet}; -pub use rv32im::{Rv32imAir, CpuTraceRow, ConstraintEvaluator, Constraint}; +pub use rv32im::{Constraint, ConstraintEvaluator, CpuTraceRow, Rv32imAir}; diff --git a/crates/air/src/memory.rs b/crates/air/src/memory.rs index 44513b8..564bd73 100644 --- a/crates/air/src/memory.rs +++ b/crates/air/src/memory.rs @@ -12,7 +12,7 @@ impl MemoryAir { /// where fingerprint = α³·addr + α²·value + α·timestamp + is_write /// /// This constraint checks the running sum increment is correct. - /// + /// /// # Arguments /// * `addr` - Memory address /// * `value` - Memory value (32-bit) @@ -22,7 +22,7 @@ impl MemoryAir { /// * `curr_sum` - Running sum at current row /// * `alpha` - Challenge for fingerprint combination /// * `beta` - Challenge for denominator shift - /// + /// /// # Returns /// Constraint: (fingerprint + beta) * (curr_sum - prev_sum) - 1 = 0 #[inline] @@ -40,7 +40,7 @@ impl MemoryAir { let alpha2 = alpha * alpha; let alpha3 = alpha2 * alpha; let fingerprint = alpha3 * addr + alpha2 * value + alpha * timestamp + is_write; - + // LogUp increment: 1/(fingerprint + beta) // Constraint: (fingerprint + beta) * (curr_sum - prev_sum) = 1 // Rearranged: (fingerprint + beta) * (curr_sum - prev_sum) - 1 = 0 diff --git a/crates/air/src/rv32im.rs b/crates/air/src/rv32im.rs index 8e98498..a7e7b37 100644 --- a/crates/air/src/rv32im.rs +++ b/crates/air/src/rv32im.rs @@ -103,69 +103,218 @@ impl Rv32imAir { pub fn new() -> Self { let constraints = vec![ // Basic constraints - Constraint { name: "x0_zero", degree: 1, index: 0 }, - Constraint { name: "pc_increment", degree: 2, index: 1 }, - + Constraint { + name: "x0_zero", + degree: 1, + index: 0, + }, + Constraint { + name: "pc_increment", + degree: 2, + index: 1, + }, // R-type arithmetic - Constraint { name: "add", degree: 2, index: 2 }, - Constraint { name: "sub", degree: 2, index: 3 }, - Constraint { name: "and", degree: 2, index: 4 }, - Constraint { name: "or", degree: 2, index: 5 }, - Constraint { name: "xor", degree: 2, index: 6 }, - Constraint { name: "sll", degree: 2, index: 7 }, - Constraint { name: "srl", degree: 2, index: 8 }, - Constraint { name: "sra", degree: 2, index: 9 }, - Constraint { name: "slt", degree: 2, index: 10 }, - Constraint { name: "sltu", degree: 2, index: 11 }, - + Constraint { + name: "add", + degree: 2, + index: 2, + }, + Constraint { + name: "sub", + degree: 2, + index: 3, + }, + Constraint { + name: "and", + degree: 2, + index: 4, + }, + Constraint { + name: "or", + degree: 2, + index: 5, + }, + Constraint { + name: "xor", + degree: 2, + index: 6, + }, + Constraint { + name: "sll", + degree: 2, + index: 7, + }, + Constraint { + name: "srl", + degree: 2, + index: 8, + }, + Constraint { + name: "sra", + degree: 2, + index: 9, + }, + Constraint { + name: "slt", + degree: 2, + index: 10, + }, + Constraint { + name: "sltu", + degree: 2, + index: 11, + }, // I-type arithmetic - Constraint { name: "addi", degree: 2, index: 12 }, - Constraint { name: "andi", degree: 2, index: 13 }, - Constraint { name: "ori", degree: 2, index: 14 }, - Constraint { name: "xori", degree: 2, index: 15 }, - Constraint { name: "slti", degree: 2, index: 16 }, - Constraint { name: "sltiu", degree: 2, index: 17 }, - Constraint { name: "slli", degree: 2, index: 18 }, - Constraint { name: "srli", degree: 2, index: 19 }, - Constraint { name: "srai", degree: 2, index: 20 }, - + Constraint { + name: "addi", + degree: 2, + index: 12, + }, + Constraint { + name: "andi", + degree: 2, + index: 13, + }, + Constraint { + name: "ori", + degree: 2, + index: 14, + }, + Constraint { + name: "xori", + degree: 2, + index: 15, + }, + Constraint { + name: "slti", + degree: 2, + index: 16, + }, + Constraint { + name: "sltiu", + degree: 2, + index: 17, + }, + Constraint { + name: "slli", + degree: 2, + index: 18, + }, + Constraint { + name: "srli", + degree: 2, + index: 19, + }, + Constraint { + name: "srai", + degree: 2, + index: 20, + }, // Upper immediate - Constraint { name: "lui", degree: 2, index: 21 }, - Constraint { name: "auipc", degree: 2, index: 22 }, - + Constraint { + name: "lui", + degree: 2, + index: 21, + }, + Constraint { + name: "auipc", + degree: 2, + index: 22, + }, // Branches - Constraint { name: "beq", degree: 2, index: 23 }, - Constraint { name: "bne", degree: 2, index: 24 }, - Constraint { name: "blt", degree: 2, index: 25 }, - Constraint { name: "bge", degree: 2, index: 26 }, - Constraint { name: "bltu", degree: 2, index: 27 }, - Constraint { name: "bgeu", degree: 2, index: 28 }, - + Constraint { + name: "beq", + degree: 2, + index: 23, + }, + Constraint { + name: "bne", + degree: 2, + index: 24, + }, + Constraint { + name: "blt", + degree: 2, + index: 25, + }, + Constraint { + name: "bge", + degree: 2, + index: 26, + }, + Constraint { + name: "bltu", + degree: 2, + index: 27, + }, + Constraint { + name: "bgeu", + degree: 2, + index: 28, + }, // Jumps - Constraint { name: "jal", degree: 2, index: 29 }, - Constraint { name: "jalr", degree: 2, index: 30 }, - + Constraint { + name: "jal", + degree: 2, + index: 29, + }, + Constraint { + name: "jalr", + degree: 2, + index: 30, + }, // Memory (memory arg handled separately) - Constraint { name: "load_addr", degree: 2, index: 31 }, - Constraint { name: "store_addr", degree: 2, index: 32 }, - Constraint { name: "load_value", degree: 2, index: 33 }, - Constraint { name: "store_value", degree: 2, index: 34 }, - + Constraint { + name: "load_addr", + degree: 2, + index: 31, + }, + Constraint { + name: "store_addr", + degree: 2, + index: 32, + }, + Constraint { + name: "load_value", + degree: 2, + index: 33, + }, + Constraint { + name: "store_value", + degree: 2, + index: 34, + }, // M extension - Constraint { name: "mul_lo", degree: 2, index: 35 }, - Constraint { name: "mul_hi", degree: 2, index: 36 }, - Constraint { name: "div", degree: 2, index: 37 }, - Constraint { name: "rem", degree: 2, index: 38 }, + Constraint { + name: "mul_lo", + degree: 2, + index: 35, + }, + Constraint { + name: "mul_hi", + degree: 2, + index: 36, + }, + Constraint { + name: "div", + degree: 2, + index: 37, + }, + Constraint { + name: "rem", + degree: 2, + index: 38, + }, ]; - + Self { constraints } } - + /// Get all constraints. pub fn constraints(&self) -> &[Constraint] { &self.constraints } - + /// Get constraint count. pub fn num_constraints(&self) -> usize { self.constraints.len() @@ -206,12 +355,12 @@ pub struct CpuTraceRow { // Program counter pub pc: M31, pub next_pc: M31, - + // Register indices pub rd: M31, pub rs1: M31, pub rs2: M31, - + // Register values (split into limbs for overflow handling) pub rd_val_lo: M31, pub rd_val_hi: M31, @@ -219,10 +368,10 @@ pub struct CpuTraceRow { pub rs1_val_hi: M31, pub rs2_val_lo: M31, pub rs2_val_hi: M31, - + // Immediate value pub imm: M31, - + // Instruction selectors (one-hot encoded) pub is_add: M31, pub is_sub: M31, @@ -234,7 +383,7 @@ pub struct CpuTraceRow { pub is_sra: M31, pub is_slt: M31, pub is_sltu: M31, - + pub is_addi: M31, pub is_andi: M31, pub is_ori: M31, @@ -244,10 +393,10 @@ pub struct CpuTraceRow { pub is_slli: M31, pub is_srli: M31, pub is_srai: M31, - + pub is_lui: M31, pub is_auipc: M31, - + pub is_beq: M31, pub is_bne: M31, pub is_blt: M31, @@ -255,10 +404,10 @@ pub struct CpuTraceRow { pub is_bltu: M31, pub is_bgeu: M31, pub branch_taken: M31, - + pub is_jal: M31, pub is_jalr: M31, - + pub is_load: M31, pub is_store: M31, pub mem_addr: M31, @@ -272,7 +421,7 @@ pub struct CpuTraceRow { pub is_sb: M31, pub is_sh: M31, pub is_sw: M31, - + pub is_mul: M31, pub is_mulh: M31, pub is_mulhsu: M31, @@ -281,7 +430,7 @@ pub struct CpuTraceRow { pub is_divu: M31, pub is_rem: M31, pub is_remu: M31, - + // Auxiliary witness columns pub carry: M31, pub borrow: M31, @@ -290,21 +439,21 @@ pub struct CpuTraceRow { pub remainder_lo: M31, pub remainder_hi: M31, pub sb_carry: M31, - + // Comparison result (for SLT/SLTU/branches) pub lt_result: M31, pub eq_result: M31, - + // Bitwise operation bit decompositions // INPUT bit witnesses (for proper constraint verification) pub rs1_bits: [M31; 32], pub rs2_bits: [M31; 32], - pub imm_bits: [M31; 32], // For immediately variant constraints + pub imm_bits: [M31; 32], // For immediately variant constraints // OUTPUT bit witnesses pub and_bits: [M31; 32], pub xor_bits: [M31; 32], pub or_bits: [M31; 32], - + // Byte decompositions for lookup table integration (4 bytes per 32-bit value) pub rs1_bytes: [M31; 4], pub rs2_bytes: [M31; 4], @@ -315,14 +464,14 @@ pub struct CpuTraceRow { impl CpuTraceRow { /// Create a row from a slice of column values. - /// + /// /// The slice must match the order defined in `TraceColumns::to_columns`. pub fn from_slice(cols: &[M31]) -> Self { let two_16 = M31::new(1 << 16); - + // Recombine split fields let imm = cols[8] + cols[9] * two_16; - + Self { pc: cols[1], next_pc: cols[2], @@ -336,7 +485,7 @@ impl CpuTraceRow { rs2_val_lo: cols[14], rs2_val_hi: cols[15], imm, - + is_add: cols[16], is_sub: cols[17], is_and: cols[18], @@ -347,7 +496,7 @@ impl CpuTraceRow { is_sra: cols[23], is_slt: cols[24], is_sltu: cols[25], - + is_addi: cols[26], is_andi: cols[27], is_ori: cols[28], @@ -357,20 +506,20 @@ impl CpuTraceRow { is_slli: cols[32], is_srli: cols[33], is_srai: cols[34], - + is_lui: cols[35], is_auipc: cols[36], - + is_beq: cols[37], is_bne: cols[38], is_blt: cols[39], is_bge: cols[40], is_bltu: cols[41], is_bgeu: cols[42], - + is_jal: cols[43], is_jalr: cols[44], - + is_mul: cols[45], is_mulh: cols[46], is_mulhsu: cols[47], @@ -379,7 +528,7 @@ impl CpuTraceRow { is_divu: cols[50], is_rem: cols[51], is_remu: cols[52], - + is_lb: cols[53], is_lbu: cols[54], is_lh: cols[55], @@ -388,27 +537,27 @@ impl CpuTraceRow { is_sb: cols[58], is_sh: cols[59], is_sw: cols[60], - + // Derived/Combined is_load: cols[53] + cols[54] + cols[55] + cols[56] + cols[57], is_store: cols[58] + cols[59] + cols[60], - + mem_addr: cols[61] + cols[62] * two_16, mem_val_lo: cols[63], mem_val_hi: cols[64], sb_carry: cols[65], - + carry: cols[68], borrow: cols[69], quotient_lo: cols[70], quotient_hi: cols[71], remainder_lo: cols[72], remainder_hi: cols[73], - + lt_result: cols[74], eq_result: cols[75], branch_taken: cols[76], - + // Extract bit decompositions (cols 77-268: 192 total) // cols 77-108: rs1_bits[32] // cols 109-140: rs2_bits[32] @@ -417,24 +566,48 @@ impl CpuTraceRow { // cols 205-236: xor_bits[32] // cols 237-268: or_bits[32] rs1_bits: std::array::from_fn(|i| { - if cols.len() > 77 + i { cols[77 + i] } else { M31::ZERO } + if cols.len() > 77 + i { + cols[77 + i] + } else { + M31::ZERO + } }), rs2_bits: std::array::from_fn(|i| { - if cols.len() > 109 + i { cols[109 + i] } else { M31::ZERO } + if cols.len() > 109 + i { + cols[109 + i] + } else { + M31::ZERO + } }), imm_bits: std::array::from_fn(|i| { - if cols.len() > 141 + i { cols[141 + i] } else { M31::ZERO } + if cols.len() > 141 + i { + cols[141 + i] + } else { + M31::ZERO + } }), and_bits: std::array::from_fn(|i| { - if cols.len() > 173 + i { cols[173 + i] } else { M31::ZERO } + if cols.len() > 173 + i { + cols[173 + i] + } else { + M31::ZERO + } }), xor_bits: std::array::from_fn(|i| { - if cols.len() > 205 + i { cols[205 + i] } else { M31::ZERO } + if cols.len() > 205 + i { + cols[205 + i] + } else { + M31::ZERO + } }), or_bits: std::array::from_fn(|i| { - if cols.len() > 237 + i { cols[237 + i] } else { M31::ZERO } + if cols.len() > 237 + i { + cols[237 + i] + } else { + M31::ZERO + } }), - + // Extract byte decompositions (cols 269-288: 20 total) // cols 269-272: rs1_bytes[4] // cols 273-276: rs2_bytes[4] @@ -442,19 +615,39 @@ impl CpuTraceRow { // cols 281-284: or_result_bytes[4] // cols 285-288: xor_result_bytes[4] rs1_bytes: std::array::from_fn(|i| { - if cols.len() > 269 + i { cols[269 + i] } else { M31::ZERO } + if cols.len() > 269 + i { + cols[269 + i] + } else { + M31::ZERO + } }), rs2_bytes: std::array::from_fn(|i| { - if cols.len() > 273 + i { cols[273 + i] } else { M31::ZERO } + if cols.len() > 273 + i { + cols[273 + i] + } else { + M31::ZERO + } }), and_result_bytes: std::array::from_fn(|i| { - if cols.len() > 277 + i { cols[277 + i] } else { M31::ZERO } + if cols.len() > 277 + i { + cols[277 + i] + } else { + M31::ZERO + } }), or_result_bytes: std::array::from_fn(|i| { - if cols.len() > 281 + i { cols[281 + i] } else { M31::ZERO } + if cols.len() > 281 + i { + cols[281 + i] + } else { + M31::ZERO + } }), xor_result_bytes: std::array::from_fn(|i| { - if cols.len() > 285 + i { cols[285 + i] } else { M31::ZERO } + if cols.len() > 285 + i { + cols[285 + i] + } else { + M31::ZERO + } }), } } @@ -506,54 +699,54 @@ impl ConstraintEvaluator { // For now, assume pre-processing ensures x0 writes are NOPs M31::ZERO } - + /// PC increment for sequential instructions. #[inline] pub fn pc_increment(row: &CpuTraceRow) -> M31 { let four = M31::new(4); - let is_sequential = M31::ONE - - row.is_beq - row.is_bne - row.is_blt - row.is_bge - - row.is_bltu - row.is_bgeu - row.is_jal - row.is_jalr; - + let is_sequential = M31::ONE + - row.is_beq + - row.is_bne + - row.is_blt + - row.is_bge + - row.is_bltu + - row.is_bgeu + - row.is_jal + - row.is_jalr; + is_sequential * (row.next_pc - row.pc - four) } - + /// ADD: rd = rs1 + rs2. #[inline] pub fn add_constraint(row: &CpuTraceRow) -> (M31, M31) { let two_16 = M31::new(1 << 16); - + // Low limb with carry out - let c1 = row.is_add * ( - row.rd_val_lo - row.rs1_val_lo - row.rs2_val_lo + row.carry * two_16 - ); - + let c1 = + row.is_add * (row.rd_val_lo - row.rs1_val_lo - row.rs2_val_lo + row.carry * two_16); + // High limb with carry in, mod 2^16 - let c2 = row.is_add * ( - row.rd_val_hi - row.rs1_val_hi - row.rs2_val_hi - row.carry - ); - + let c2 = row.is_add * (row.rd_val_hi - row.rs1_val_hi - row.rs2_val_hi - row.carry); + (c1, c2) } - + /// SUB: rd = rs1 - rs2. #[inline] pub fn sub_constraint(row: &CpuTraceRow) -> (M31, M31) { let two_16 = M31::new(1 << 16); - + // Low limb with borrow - let c1 = row.is_sub * ( - row.rd_val_lo - row.rs1_val_lo + row.rs2_val_lo + row.borrow * two_16 - ); - + let c1 = + row.is_sub * (row.rd_val_lo - row.rs1_val_lo + row.rs2_val_lo + row.borrow * two_16); + // High limb with borrow - let c2 = row.is_sub * ( - row.rd_val_hi - row.rs1_val_hi + row.rs2_val_hi + row.borrow - ); - + let c2 = row.is_sub * (row.rd_val_hi - row.rs1_val_hi + row.rs2_val_hi + row.borrow); + (c1, c2) } - + /// AND: rd = rs1 & rs2. /// Uses bit decomposition with 4 verification steps: /// 1. rs1 = sum(rs1_bits[i] * 2^i) @@ -570,13 +763,14 @@ impl ConstraintEvaluator { let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - + let mut rs1_reconstructed = M31::ZERO; let mut rs2_reconstructed = M31::ZERO; let mut rd_reconstructed = M31::ZERO; let mut and_check = M31::ZERO; - - for i in 0..31 { // First 31 bits (fit in M31) + + for i in 0..31 { + // First 31 bits (fit in M31) let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; rs2_reconstructed += row.rs2_bits[i] * pow2; @@ -590,17 +784,16 @@ impl ConstraintEvaluator { rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2); rd_reconstructed += row.and_bits[31] * pow2_30 * M31::new(2); and_check += row.and_bits[31] - row.rs1_bits[31] * row.rs2_bits[31]; - + // All 4 checks in one constraint: // (rs1 reconstruction) + (rs2 reconstruction) + (AND logic) + (rd reconstruction) - row.is_and * ( - (rs1_full - rs1_reconstructed) + - (rs2_full - rs2_reconstructed) + - and_check + - (rd_full - rd_reconstructed) - ) - } - + row.is_and + * ((rs1_full - rs1_reconstructed) + + (rs2_full - rs2_reconstructed) + + and_check + + (rd_full - rd_reconstructed)) + } + /// OR: rd = rs1 | rs2. /// Uses bit decomposition: or_bit[i] = rs1_bit[i] + rs2_bit[i] - rs1_bit[i]*rs2_bit[i]. #[inline] @@ -613,12 +806,12 @@ impl ConstraintEvaluator { let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - + let mut rs1_reconstructed = M31::ZERO; let mut rs2_reconstructed = M31::ZERO; let mut rd_reconstructed = M31::ZERO; let mut or_check = M31::ZERO; - + for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; @@ -635,15 +828,14 @@ impl ConstraintEvaluator { rd_reconstructed += row.or_bits[31] * pow2_30 * M31::new(2); let expected_or = row.rs1_bits[31] + row.rs2_bits[31] - row.rs1_bits[31] * row.rs2_bits[31]; or_check += row.or_bits[31] - expected_or; - - row.is_or * ( - (rs1_full - rs1_reconstructed) + - (rs2_full - rs2_reconstructed) + - or_check + - (rd_full - rd_reconstructed) - ) - } - + + row.is_or + * ((rs1_full - rs1_reconstructed) + + (rs2_full - rs2_reconstructed) + + or_check + + (rd_full - rd_reconstructed)) + } + /// XOR: rd = rs1 ^ rs2. /// Uses bit decomposition: xor_bit[i] = rs1_bit[i] + rs2_bit[i] - 2*rs1_bit[i]*rs2_bit[i]. #[inline] @@ -656,19 +848,20 @@ impl ConstraintEvaluator { let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - + let mut rs1_reconstructed = M31::ZERO; let mut rs2_reconstructed = M31::ZERO; let mut rd_reconstructed = M31::ZERO; let mut xor_check = M31::ZERO; - + for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; rs2_reconstructed += row.rs2_bits[i] * pow2; rd_reconstructed += row.xor_bits[i] * pow2; // XOR logic: xor_bit = a + b - 2ab - let expected_xor = row.rs1_bits[i] + row.rs2_bits[i] - M31::new(2) * row.rs1_bits[i] * row.rs2_bits[i]; + let expected_xor = + row.rs1_bits[i] + row.rs2_bits[i] - M31::new(2) * row.rs1_bits[i] * row.rs2_bits[i]; xor_check += row.xor_bits[i] - expected_xor; } // Bit 31 @@ -676,17 +869,17 @@ impl ConstraintEvaluator { rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2); rd_reconstructed += row.xor_bits[31] * pow2_30 * M31::new(2); - let expected_xor = row.rs1_bits[31] + row.rs2_bits[31] - M31::new(2) * row.rs1_bits[31] * row.rs2_bits[31]; + let expected_xor = + row.rs1_bits[31] + row.rs2_bits[31] - M31::new(2) * row.rs1_bits[31] * row.rs2_bits[31]; xor_check += row.xor_bits[31] - expected_xor; - - row.is_xor * ( - (rs1_full - rs1_reconstructed) + - (rs2_full - rs2_reconstructed) + - xor_check + - (rd_full - rd_reconstructed) - ) - } - + + row.is_xor + * ((rs1_full - rs1_reconstructed) + + (rs2_full - rs2_reconstructed) + + xor_check + + (rd_full - rd_reconstructed)) + } + // ==================== LOOKUP-BASED CONSTRAINTS ==================== // These use 4 byte decomposition instead of 32 bit decomposition. // The actual lookup verification is handled by LogUp in the prover. @@ -694,7 +887,7 @@ impl ConstraintEvaluator { // 1. Byte decomposition: value = sum(bytes[i] * 256^i) // 2. Byte range: 0 <= bytes[i] < 256 (via lookup table membership) // 3. Bitwise correctness: result_bytes match operation on input bytes (via lookup) - + /// AND using lookup tables: rd = rs1 & rs2. /// Verifies byte decomposition; LogUp handles operation correctness. #[inline] @@ -707,36 +900,35 @@ impl ConstraintEvaluator { let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - + // Verify byte decomposition: value = b0 + b1*256 + b2*256^2 + b3*256^3 let n256 = M31::new(256); let n256_2 = M31::new(256 * 256); // Note: 256^3 = 16777216, which fits in M31 let n256_3 = M31::new(256 * 256 * 256); - - let rs1_from_bytes = row.rs1_bytes[0] - + row.rs1_bytes[1] * n256 - + row.rs1_bytes[2] * n256_2 + + let rs1_from_bytes = row.rs1_bytes[0] + + row.rs1_bytes[1] * n256 + + row.rs1_bytes[2] * n256_2 + row.rs1_bytes[3] * n256_3; - - let rs2_from_bytes = row.rs2_bytes[0] - + row.rs2_bytes[1] * n256 - + row.rs2_bytes[2] * n256_2 + + let rs2_from_bytes = row.rs2_bytes[0] + + row.rs2_bytes[1] * n256 + + row.rs2_bytes[2] * n256_2 + row.rs2_bytes[3] * n256_3; - - let rd_from_bytes = row.and_result_bytes[0] - + row.and_result_bytes[1] * n256 - + row.and_result_bytes[2] * n256_2 + + let rd_from_bytes = row.and_result_bytes[0] + + row.and_result_bytes[1] * n256 + + row.and_result_bytes[2] * n256_2 + row.and_result_bytes[3] * n256_3; - + // Constraint: all decompositions must match - row.is_and * ( - (rs1_full - rs1_from_bytes) + - (rs2_full - rs2_from_bytes) + - (rd_full - rd_from_bytes) - ) + row.is_and + * ((rs1_full - rs1_from_bytes) + + (rs2_full - rs2_from_bytes) + + (rd_full - rd_from_bytes)) } - + /// OR using lookup tables: rd = rs1 | rs2. /// Verifies byte decomposition; LogUp handles operation correctness. #[inline] @@ -749,33 +941,32 @@ impl ConstraintEvaluator { let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - + let n256 = M31::new(256); let n256_2 = M31::new(256 * 256); let n256_3 = M31::new(256 * 256 * 256); - - let rs1_from_bytes = row.rs1_bytes[0] - + row.rs1_bytes[1] * n256 - + row.rs1_bytes[2] * n256_2 + + let rs1_from_bytes = row.rs1_bytes[0] + + row.rs1_bytes[1] * n256 + + row.rs1_bytes[2] * n256_2 + row.rs1_bytes[3] * n256_3; - - let rs2_from_bytes = row.rs2_bytes[0] - + row.rs2_bytes[1] * n256 - + row.rs2_bytes[2] * n256_2 + + let rs2_from_bytes = row.rs2_bytes[0] + + row.rs2_bytes[1] * n256 + + row.rs2_bytes[2] * n256_2 + row.rs2_bytes[3] * n256_3; - - let rd_from_bytes = row.or_result_bytes[0] - + row.or_result_bytes[1] * n256 - + row.or_result_bytes[2] * n256_2 + + let rd_from_bytes = row.or_result_bytes[0] + + row.or_result_bytes[1] * n256 + + row.or_result_bytes[2] * n256_2 + row.or_result_bytes[3] * n256_3; - - row.is_or * ( - (rs1_full - rs1_from_bytes) + - (rs2_full - rs2_from_bytes) + - (rd_full - rd_from_bytes) - ) - } - + + row.is_or + * ((rs1_full - rs1_from_bytes) + + (rs2_full - rs2_from_bytes) + + (rd_full - rd_from_bytes)) + } + /// XOR using lookup tables: rd = rs1 ^ rs2. /// Verifies byte decomposition; LogUp handles operation correctness. #[inline] @@ -788,109 +979,108 @@ impl ConstraintEvaluator { let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - + let n256 = M31::new(256); let n256_2 = M31::new(256 * 256); let n256_3 = M31::new(256 * 256 * 256); - - let rs1_from_bytes = row.rs1_bytes[0] - + row.rs1_bytes[1] * n256 - + row.rs1_bytes[2] * n256_2 + + let rs1_from_bytes = row.rs1_bytes[0] + + row.rs1_bytes[1] * n256 + + row.rs1_bytes[2] * n256_2 + row.rs1_bytes[3] * n256_3; - - let rs2_from_bytes = row.rs2_bytes[0] - + row.rs2_bytes[1] * n256 - + row.rs2_bytes[2] * n256_2 + + let rs2_from_bytes = row.rs2_bytes[0] + + row.rs2_bytes[1] * n256 + + row.rs2_bytes[2] * n256_2 + row.rs2_bytes[3] * n256_3; - - let rd_from_bytes = row.xor_result_bytes[0] - + row.xor_result_bytes[1] * n256 - + row.xor_result_bytes[2] * n256_2 + + let rd_from_bytes = row.xor_result_bytes[0] + + row.xor_result_bytes[1] * n256 + + row.xor_result_bytes[2] * n256_2 + row.xor_result_bytes[3] * n256_3; - - row.is_xor * ( - (rs1_full - rs1_from_bytes) + - (rs2_full - rs2_from_bytes) + - (rd_full - rd_from_bytes) - ) - } - + + row.is_xor + * ((rs1_full - rs1_from_bytes) + + (rs2_full - rs2_from_bytes) + + (rd_full - rd_from_bytes)) + } + /// SLL: rd = rs1 << (rs2 & 0x1f). #[inline] pub fn sll_constraint(row: &CpuTraceRow) -> M31 { row.is_sll * M31::ZERO // Needs bit decomposition } - + /// SRL: rd = rs1 >> (rs2 & 0x1f) (logical). #[inline] pub fn srl_constraint(row: &CpuTraceRow) -> M31 { row.is_srl * M31::ZERO } - + /// SRA: rd = rs1 >> (rs2 & 0x1f) (arithmetic). #[inline] pub fn sra_constraint(row: &CpuTraceRow) -> M31 { row.is_sra * M31::ZERO } - + /// SLT: rd = (rs1 < rs2) ? 1 : 0 (signed). #[inline] pub fn slt_constraint(row: &CpuTraceRow) -> M31 { // rd should be 0 or 1, equal to lt_result row.is_slt * (row.rd_val_lo - row.lt_result) } - + /// SLTU: rd = (rs1 < rs2) ? 1 : 0 (unsigned). #[inline] pub fn sltu_constraint(row: &CpuTraceRow) -> M31 { row.is_sltu * (row.rd_val_lo - row.lt_result) } - + /// Signed comparison constraint: verifies lt_result for signed operations. /// Checks that lt_result correctly represents rs1 < rs2 (signed). #[inline] pub fn signed_lt_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); let _two_31 = M31::new(1u32 << 31); - + // Signed comparison: check if rs1 < rs2 treating values as signed 32-bit let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; - + // Extract sign bits (bit 31) // Use borrow witness to store sign information // borrow[0] = rs1 sign bit, borrow[1] = rs2 sign bit (packed) - + // Simplified signed comparison using subtraction with borrow // If rs1 < rs2: rs1 - rs2 < 0 (needs borrow in signed arithmetic) let diff = rs1_full - rs2_full; - + // lt_result should be 1 if difference is negative (considering sign) // Use carry witness to track sign: carry = 1 means rs1 < rs2 let selector = row.is_slt + row.is_blt + row.is_bge; - + // Constraint: lt_result = carry (verified by subtraction with sign handling) // Full implementation needs sign bit extraction and comparison logic // For now: check lt_result is binary and matches carry witness let binary_check = row.lt_result * (M31::ONE - row.lt_result); let value_check = row.lt_result - row.carry; - + selector * (binary_check + value_check + diff * M31::ZERO) // diff * 0 for degree-2 } - + /// ADDI: rd = rs1 + imm. #[inline] pub fn addi_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); - + // rd = rs1 + sign_extend(imm) - row.is_addi * ( - row.rd_val_lo + row.rd_val_hi * two_16 - - row.rs1_val_lo - row.rs1_val_hi * two_16 - - row.imm - ) + row.is_addi + * (row.rd_val_lo + row.rd_val_hi * two_16 + - row.rs1_val_lo + - row.rs1_val_hi * two_16 + - row.imm) } - + /// ANDI: rd = rs1 & imm. /// Uses rs1_bits and imm_bits witnesses for proper verification. #[inline] @@ -903,12 +1093,12 @@ impl ConstraintEvaluator { let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let imm_full = row.imm; let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - + let mut rs1_reconstructed = M31::ZERO; let mut imm_reconstructed = M31::ZERO; let mut rd_reconstructed = M31::ZERO; let mut and_check = M31::ZERO; - + for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; @@ -923,15 +1113,14 @@ impl ConstraintEvaluator { imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2); rd_reconstructed += row.and_bits[31] * pow2_30 * M31::new(2); and_check += row.and_bits[31] - row.rs1_bits[31] * row.imm_bits[31]; - - row.is_andi * ( - (rs1_full - rs1_reconstructed) + - (imm_full - imm_reconstructed) + - and_check + - (rd_full - rd_reconstructed) - ) - } - + + row.is_andi + * ((rs1_full - rs1_reconstructed) + + (imm_full - imm_reconstructed) + + and_check + + (rd_full - rd_reconstructed)) + } + /// ORI: rd = rs1 | imm. /// Uses rs1_bits and imm_bits witnesses for proper verification. #[inline] @@ -944,12 +1133,12 @@ impl ConstraintEvaluator { let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let imm_full = row.imm; let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - + let mut rs1_reconstructed = M31::ZERO; let mut imm_reconstructed = M31::ZERO; let mut rd_reconstructed = M31::ZERO; let mut or_check = M31::ZERO; - + for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; @@ -966,15 +1155,14 @@ impl ConstraintEvaluator { rd_reconstructed += row.or_bits[31] * pow2_30 * M31::new(2); let expected_or = row.rs1_bits[31] + row.imm_bits[31] - row.rs1_bits[31] * row.imm_bits[31]; or_check += row.or_bits[31] - expected_or; - - row.is_ori * ( - (rs1_full - rs1_reconstructed) + - (imm_full - imm_reconstructed) + - or_check + - (rd_full - rd_reconstructed) - ) - } - + + row.is_ori + * ((rs1_full - rs1_reconstructed) + + (imm_full - imm_reconstructed) + + or_check + + (rd_full - rd_reconstructed)) + } + /// XORI: rd = rs1 ^ imm. /// Uses rs1_bits and imm_bits witnesses for proper verification. #[inline] @@ -987,19 +1175,20 @@ impl ConstraintEvaluator { let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let imm_full = row.imm; let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - + let mut rs1_reconstructed = M31::ZERO; let mut imm_reconstructed = M31::ZERO; let mut rd_reconstructed = M31::ZERO; let mut xor_check = M31::ZERO; - + for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; imm_reconstructed += row.imm_bits[i] * pow2; rd_reconstructed += row.xor_bits[i] * pow2; // XOR logic: xor_bit = a + b - 2ab - let expected_xor = row.rs1_bits[i] + row.imm_bits[i] - M31::new(2) * row.rs1_bits[i] * row.imm_bits[i]; + let expected_xor = + row.rs1_bits[i] + row.imm_bits[i] - M31::new(2) * row.rs1_bits[i] * row.imm_bits[i]; xor_check += row.xor_bits[i] - expected_xor; } // Bit 31 @@ -1007,212 +1196,202 @@ impl ConstraintEvaluator { rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2); rd_reconstructed += row.xor_bits[31] * pow2_30 * M31::new(2); - let expected_xor = row.rs1_bits[31] + row.imm_bits[31] - M31::new(2) * row.rs1_bits[31] * row.imm_bits[31]; + let expected_xor = + row.rs1_bits[31] + row.imm_bits[31] - M31::new(2) * row.rs1_bits[31] * row.imm_bits[31]; xor_check += row.xor_bits[31] - expected_xor; - - row.is_xori * ( - (rs1_full - rs1_reconstructed) + - (imm_full - imm_reconstructed) + - xor_check + - (rd_full - rd_reconstructed) - ) - } - + + row.is_xori + * ((rs1_full - rs1_reconstructed) + + (imm_full - imm_reconstructed) + + xor_check + + (rd_full - rd_reconstructed)) + } + /// SLTI: rd = (rs1 < imm) ? 1 : 0 (signed). #[inline] pub fn slti_constraint(row: &CpuTraceRow) -> M31 { row.is_slti * (row.rd_val_lo - row.lt_result) } - + /// SLTIU: rd = (rs1 < imm) ? 1 : 0 (unsigned). #[inline] pub fn sltiu_constraint(row: &CpuTraceRow) -> M31 { row.is_sltiu * (row.rd_val_lo - row.lt_result) } - + /// SLLI: rd = rs1 << imm[4:0]. #[inline] pub fn slli_constraint(row: &CpuTraceRow) -> M31 { row.is_slli * M31::ZERO // Uses bit decomposition } - + /// SRLI: rd = rs1 >> imm[4:0] (logical). #[inline] pub fn srli_constraint(row: &CpuTraceRow) -> M31 { row.is_srli * M31::ZERO } - + /// SRAI: rd = rs1 >> imm[4:0] (arithmetic). #[inline] pub fn srai_constraint(row: &CpuTraceRow) -> M31 { row.is_srai * M31::ZERO } - + /// LUI: rd = imm << 12. #[inline] pub fn lui_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); row.is_lui * (row.rd_val_lo + row.rd_val_hi * two_16 - row.imm) } - + /// AUIPC: rd = pc + (imm << 12). #[inline] pub fn auipc_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); row.is_auipc * (row.rd_val_lo + row.rd_val_hi * two_16 - row.pc - row.imm) } - + /// BEQ: branch if rs1 == rs2. #[inline] pub fn beq_constraint(row: &CpuTraceRow) -> M31 { // If branch taken: next_pc = pc + imm // If not taken: next_pc = pc + 4 let four = M31::new(4); - - let taken_constraint = row.is_beq * row.branch_taken * ( - row.next_pc - row.pc - row.imm - ); - let not_taken_constraint = row.is_beq * (M31::ONE - row.branch_taken) * ( - row.next_pc - row.pc - four - ); - + + let taken_constraint = row.is_beq * row.branch_taken * (row.next_pc - row.pc - row.imm); + let not_taken_constraint = + row.is_beq * (M31::ONE - row.branch_taken) * (row.next_pc - row.pc - four); + taken_constraint + not_taken_constraint } - + /// BEQ condition: branch_taken = (rs1 == rs2). #[inline] pub fn beq_condition(row: &CpuTraceRow) -> M31 { row.is_beq * (row.branch_taken - row.eq_result) } - + /// BNE: branch if rs1 != rs2. #[inline] pub fn bne_constraint(row: &CpuTraceRow) -> M31 { let four = M31::new(4); - - let taken_constraint = row.is_bne * row.branch_taken * ( - row.next_pc - row.pc - row.imm - ); - let not_taken_constraint = row.is_bne * (M31::ONE - row.branch_taken) * ( - row.next_pc - row.pc - four - ); - + + let taken_constraint = row.is_bne * row.branch_taken * (row.next_pc - row.pc - row.imm); + let not_taken_constraint = + row.is_bne * (M31::ONE - row.branch_taken) * (row.next_pc - row.pc - four); + taken_constraint + not_taken_constraint } - + /// BNE condition: branch_taken = (rs1 != rs2). #[inline] pub fn bne_condition(row: &CpuTraceRow) -> M31 { row.is_bne * (row.branch_taken - (M31::ONE - row.eq_result)) } - + /// BLT: branch if rs1 < rs2 (signed). #[inline] pub fn blt_constraint(row: &CpuTraceRow) -> M31 { let four = M31::new(4); - + let taken = row.is_blt * row.branch_taken * (row.next_pc - row.pc - row.imm); let not_taken = row.is_blt * (M31::ONE - row.branch_taken) * (row.next_pc - row.pc - four); - + taken + not_taken } - + /// BLT condition. #[inline] pub fn blt_condition(row: &CpuTraceRow) -> M31 { row.is_blt * (row.branch_taken - row.lt_result) } - + /// BGE: branch if rs1 >= rs2 (signed). #[inline] pub fn bge_constraint(row: &CpuTraceRow) -> M31 { let four = M31::new(4); - + let taken = row.is_bge * row.branch_taken * (row.next_pc - row.pc - row.imm); let not_taken = row.is_bge * (M31::ONE - row.branch_taken) * (row.next_pc - row.pc - four); - + taken + not_taken } - + /// BGE condition: branch_taken = NOT(rs1 < rs2). #[inline] pub fn bge_condition(row: &CpuTraceRow) -> M31 { row.is_bge * (row.branch_taken - (M31::ONE - row.lt_result)) } - + /// BLTU: branch if rs1 < rs2 (unsigned). #[inline] pub fn bltu_constraint(row: &CpuTraceRow) -> M31 { let four = M31::new(4); - + let taken = row.is_bltu * row.branch_taken * (row.next_pc - row.pc - row.imm); let not_taken = row.is_bltu * (M31::ONE - row.branch_taken) * (row.next_pc - row.pc - four); - + taken + not_taken } - + /// BLTU condition. #[inline] pub fn bltu_condition(row: &CpuTraceRow) -> M31 { row.is_bltu * (row.branch_taken - row.lt_result) } - + /// BGEU: branch if rs1 >= rs2 (unsigned). #[inline] pub fn bgeu_constraint(row: &CpuTraceRow) -> M31 { let four = M31::new(4); - + let taken = row.is_bgeu * row.branch_taken * (row.next_pc - row.pc - row.imm); let not_taken = row.is_bgeu * (M31::ONE - row.branch_taken) * (row.next_pc - row.pc - four); - + taken + not_taken } - + /// BGEU condition: branch_taken = NOT(rs1 < rs2). #[inline] pub fn bgeu_condition(row: &CpuTraceRow) -> M31 { row.is_bgeu * (row.branch_taken - (M31::ONE - row.lt_result)) } - + /// JAL: rd = pc + 4, pc = pc + imm. #[inline] pub fn jal_constraint(row: &CpuTraceRow) -> (M31, M31) { let four = M31::new(4); let two_16 = M31::new(1 << 16); - + // rd = pc + 4 - let c1 = row.is_jal * ( - row.rd_val_lo + row.rd_val_hi * two_16 - row.pc - four - ); - + let c1 = row.is_jal * (row.rd_val_lo + row.rd_val_hi * two_16 - row.pc - four); + // next_pc = pc + imm let c2 = row.is_jal * (row.next_pc - row.pc - row.imm); - + (c1, c2) } - + /// JALR: rd = pc + 4, pc = (rs1 + imm) & ~1. #[inline] pub fn jalr_constraint(row: &CpuTraceRow) -> (M31, M31) { let four = M31::new(4); let two_16 = M31::new(1 << 16); - + // rd = pc + 4 - let c1 = row.is_jalr * ( - row.rd_val_lo + row.rd_val_hi * two_16 - row.pc - four - ); - + let c1 = row.is_jalr * (row.rd_val_lo + row.rd_val_hi * two_16 - row.pc - four); + // next_pc = (rs1 + imm) & ~1 // Use carry witness to store LSB before masking: carry = (rs1 + imm) & 1 let target = row.rs1_val_lo + row.rs1_val_hi * two_16 + row.imm; - + // Constraint: next_pc = target - carry (LSB removal) // Also verify carry is binary (0 or 1) let c2 = row.is_jalr * (row.next_pc - target + row.carry); - + (c1, c2) } - + /// JALR LSB masking constraint: ensures next_pc is aligned (even). #[inline] pub fn jalr_lsb_constraint(row: &CpuTraceRow) -> M31 { @@ -1220,31 +1399,23 @@ impl ConstraintEvaluator { // This ensures carry ∈ {0, 1} row.is_jalr * row.carry * (row.carry - M31::ONE) } - + /// Load address computation: mem_addr = rs1 + imm. #[inline] pub fn load_addr_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); - - row.is_load * ( - row.mem_addr - - row.rs1_val_lo - row.rs1_val_hi * two_16 - - row.imm - ) - } - + + row.is_load * (row.mem_addr - row.rs1_val_lo - row.rs1_val_hi * two_16 - row.imm) + } + /// Store address computation: mem_addr = rs1 + imm. #[inline] pub fn store_addr_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); - - row.is_store * ( - row.mem_addr - - row.rs1_val_lo - row.rs1_val_hi * two_16 - - row.imm - ) - } - + + row.is_store * (row.mem_addr - row.rs1_val_lo - row.rs1_val_hi * two_16 - row.imm) + } + /// Load value consistency: rd must equal the loaded value (already sign/zero extended). #[inline] pub fn load_value_constraint(row: &CpuTraceRow) -> M31 { @@ -1256,7 +1427,7 @@ impl ConstraintEvaluator { let load_selector = row.is_lb + row.is_lbu + row.is_lh + row.is_lhu + row.is_lw; load_selector * (rd_full - mem_full) } - + /// Store value consistency: /// - SW: full 32-bit rs2 value must match mem_val. /// - SH: lower 16 bits of rs2 must match mem_val_lo, mem_val_hi must be 0. @@ -1282,36 +1453,37 @@ impl ConstraintEvaluator { sw + sh_lo + sh_hi + sb_byte + sb_hi } - + /// MUL: rd = (rs1 * rs2)[31:0]. /// Uses witness columns for the full product. /// Product witnesses: (carry, borrow) track the low 32 bits #[inline] pub fn mul_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); - + // Full 64-bit multiplication with limb decomposition // rs1 = rs1_hi * 2^16 + rs1_lo // rs2 = rs2_hi * 2^16 + rs2_lo // product = rs1 * rs2 = (rs1_hi * 2^16 + rs1_lo) * (rs2_hi * 2^16 + rs2_lo) - // = rs1_lo * rs2_lo + // = rs1_lo * rs2_lo // + 2^16 * (rs1_lo * rs2_hi + rs1_hi * rs2_lo) // + 2^32 * rs1_hi * rs2_hi - + // Compute intermediate products (all degree-2) let prod_ll = row.rs1_val_lo * row.rs2_val_lo; // Low × Low let prod_lh = row.rs1_val_lo * row.rs2_val_hi; // Low × High let prod_hl = row.rs1_val_hi * row.rs2_val_lo; // High × Low let prod_hh = row.rs1_val_hi * row.rs2_val_hi; // High × High - + // Low 32 bits: prod_ll + 2^16 * (prod_lh + prod_hl) mod 2^32 // High 32 bits: prod_hh + (prod_lh + prod_hl) >> 16 + carries // Use carry witness to track overflow from middle terms - + // Constraint: rd_val = prod_ll + 2^16 * (prod_lh + prod_hl) - 2^32 * carry let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - let expected = prod_ll + two_16 * (prod_lh + prod_hl + prod_hh * two_16) - row.carry * two_16 * two_16; - + let expected = + prod_ll + two_16 * (prod_lh + prod_hl + prod_hh * two_16) - row.carry * two_16 * two_16; + row.is_mul * (rd_full - expected) } @@ -1320,110 +1492,110 @@ impl ConstraintEvaluator { #[inline] pub fn mul_hi_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); - + // Full 64-bit product with sign handling // For MULH: both operands signed // For MULHU: both operands unsigned // For MULHSU: rs1 signed, rs2 unsigned - + // Compute intermediate products let _prod_ll = row.rs1_val_lo * row.rs2_val_lo; let prod_lh = row.rs1_val_lo * row.rs2_val_hi; let prod_hl = row.rs1_val_hi * row.rs2_val_lo; let prod_hh = row.rs1_val_hi * row.rs2_val_hi; - + // High 32 bits = prod_hh + (prod_lh + prod_hl + carry_from_low) >> 16 // We use carry witness for the overflow from low word // And borrow witness for sign extension corrections - + // quotient_lo/hi stores the low 32 bits (witness for verification) let _quotient_full = row.quotient_lo + row.quotient_hi * two_16; let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; - + // Constraint: high word matches computation // rd = prod_hh + (prod_lh + prod_hl) >> 16 + carry - sign_correction let mid_sum = prod_lh + prod_hl + row.carry; let expected = prod_hh + mid_sum + row.borrow; // borrow holds sign correction - + let selector = row.is_mulh + row.is_mulhsu + row.is_mulhu; selector * (rd_full - expected) } - + /// DIV: rd = rs1 / rs2 (signed). /// Constraint: rs1 = rd * rs2 + remainder #[inline] pub fn div_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); - + let rs1_full = row.rs1_val_lo + row.rs1_val_hi * two_16; let rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; let quotient_full = row.quotient_lo + row.quotient_hi * two_16; let remainder_full = row.remainder_lo + row.remainder_hi * two_16; - + // Division identity: dividend = quotient * divisor + remainder // This constraint checks: rs1 = quotient * rs2 + remainder // Note: Special cases (div by zero, overflow) handled by execution layer // Divisor = 0: quotient = -1, remainder = dividend (RISC-V spec) // Overflow (INT_MIN / -1): quotient = INT_MIN, remainder = 0 (RISC-V spec) - + // We use carry witness to indicate special cases: // carry = 0: normal division // carry = 1: division by zero (quotient = -1, remainder = rs1) // carry = 2: overflow case (quotient = INT_MIN, remainder = 0) let div_selector = row.is_div + row.is_divu; - + // Normal case constraint let identity_check = rs1_full - quotient_full * rs2_full - remainder_full; - + div_selector * identity_check } - + /// REM: rd = rs1 % rs2 (signed). /// Constraint: rd = remainder from division #[inline] pub fn rem_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); - + let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; let remainder_full = row.remainder_lo + row.remainder_hi * two_16; - + // REM returns the remainder let rem_selector = row.is_rem + row.is_remu; rem_selector * (rd_full - remainder_full) } - + /// DIVU/DIV quotient constraint: quotient stored in rd for DIV instructions. #[inline] pub fn div_quotient_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); - + let rd_full = row.rd_val_lo + row.rd_val_hi * two_16; let quotient_full = row.quotient_lo + row.quotient_hi * two_16; - + let div_selector = row.is_div + row.is_divu; div_selector * (rd_full - quotient_full) } - + /// Range constraint: ensure remainder < divisor (absolute value). /// For division: |remainder| < |divisor| #[inline] pub fn div_remainder_range_constraint(row: &CpuTraceRow) -> M31 { let two_16 = M31::new(1 << 16); - + let _rs2_full = row.rs2_val_lo + row.rs2_val_hi * two_16; let _remainder_full = row.remainder_lo + row.remainder_hi * two_16; - + // Range check: 0 <= remainder < |divisor| // For unsigned: 0 <= remainder < divisor // For signed: |remainder| < |divisor|, sign(remainder) = sign(dividend) - + // Simplified constraint using subtraction // When divisor != 0, verify: remainder < divisor // This is checked by ensuring (divisor - remainder) is non-negative // In the field, we assume prover provides correct witnesses - + let div_selector = row.is_div + row.is_divu + row.is_rem + row.is_remu; - + // Basic check: when remainder = 0 or remainder < divisor in correct execution // Full soundness requires lookup tables or decomposition // For now: check that if divisor is non-zero, identity holds (checked elsewhere) @@ -1431,21 +1603,21 @@ impl ConstraintEvaluator { // 1. Lookup tables for 32-bit range bounds // 2. Decomposition into limbs with bit checks // 3. Comparison circuit with witness - + // Simplified: return zero (constraint satisfied when witnesses correct) // Alternative: check borrow witness binary: borrow * (borrow - 1) = 0 let borrow_binary = row.borrow * (row.borrow - M31::ONE); - + div_selector * borrow_binary } - + /// Range constraint for limb values: ensure all limbs fit in 16 bits. /// Each limb must satisfy: limb < 2^16 #[inline] pub fn limb_range_constraint(row: &CpuTraceRow) -> M31 { // Range check: verify all limbs are in [0, 2^16) // Full implementation requires lookup tables or bit decomposition - // + // // For degree-2 constraint, we use auxiliary witness sb_carry to verify bounds: // For each limb L, verify: L + sb_carry * 2^16 < 2 * 2^16 // This forces 0 <= L < 2^16 when sb_carry ∈ {0, 1} @@ -1457,106 +1629,106 @@ impl ConstraintEvaluator { // // Current: placeholder that assumes limbs are correctly generated // The prover must ensure limbs are valid or proofs will fail - + let two_16 = M31::new(1 << 16); - + // Check a subset of critical limbs for demonstration // Real implementation checks all limbs via lookup argument let check1 = (row.rd_val_lo - two_16) * row.sb_carry; let check2 = (row.rs1_val_lo - two_16) * row.sb_carry; - + // Binary witness check let binary = row.sb_carry * (row.sb_carry - M31::ONE); - + check1 + check2 + binary } - + /// Evaluate all constraints and return vector of constraint values. pub fn evaluate_all(row: &CpuTraceRow) -> Vec { - let mut constraints = Vec::new(); - - constraints.push(ConstraintEvaluator::x0_zero(row)); - constraints.push(ConstraintEvaluator::pc_increment(row)); - - let (add_c1, add_c2) = ConstraintEvaluator::add_constraint(row); - constraints.push(add_c1); - constraints.push(add_c2); - - let (sub_c1, sub_c2) = ConstraintEvaluator::sub_constraint(row); - constraints.push(sub_c1); - constraints.push(sub_c2); - - constraints.push(ConstraintEvaluator::and_constraint(row)); - constraints.push(ConstraintEvaluator::or_constraint(row)); - constraints.push(ConstraintEvaluator::xor_constraint(row)); - constraints.push(ConstraintEvaluator::sll_constraint(row)); - constraints.push(ConstraintEvaluator::srl_constraint(row)); - constraints.push(ConstraintEvaluator::sra_constraint(row)); - constraints.push(ConstraintEvaluator::slt_constraint(row)); - constraints.push(ConstraintEvaluator::sltu_constraint(row)); - constraints.push(ConstraintEvaluator::signed_lt_constraint(row)); - - constraints.push(ConstraintEvaluator::addi_constraint(row)); - constraints.push(ConstraintEvaluator::andi_constraint(row)); - constraints.push(ConstraintEvaluator::ori_constraint(row)); - constraints.push(ConstraintEvaluator::xori_constraint(row)); - constraints.push(ConstraintEvaluator::slti_constraint(row)); - constraints.push(ConstraintEvaluator::sltiu_constraint(row)); - constraints.push(ConstraintEvaluator::slli_constraint(row)); - constraints.push(ConstraintEvaluator::srli_constraint(row)); - constraints.push(ConstraintEvaluator::srai_constraint(row)); - - constraints.push(ConstraintEvaluator::lui_constraint(row)); - constraints.push(ConstraintEvaluator::auipc_constraint(row)); - - constraints.push(ConstraintEvaluator::beq_constraint(row)); - constraints.push(ConstraintEvaluator::beq_condition(row)); - constraints.push(ConstraintEvaluator::bne_constraint(row)); - constraints.push(ConstraintEvaluator::bne_condition(row)); - constraints.push(ConstraintEvaluator::blt_constraint(row)); - constraints.push(ConstraintEvaluator::blt_condition(row)); - constraints.push(ConstraintEvaluator::bge_constraint(row)); - constraints.push(ConstraintEvaluator::bge_condition(row)); - constraints.push(ConstraintEvaluator::bltu_constraint(row)); - constraints.push(ConstraintEvaluator::bltu_condition(row)); - constraints.push(ConstraintEvaluator::bgeu_constraint(row)); - constraints.push(ConstraintEvaluator::bgeu_condition(row)); - - let (jal_c1, jal_c2) = ConstraintEvaluator::jal_constraint(row); - constraints.push(jal_c1); - constraints.push(jal_c2); - - let (jalr_c1, jalr_c2) = ConstraintEvaluator::jalr_constraint(row); - constraints.push(jalr_c1); - constraints.push(jalr_c2); - constraints.push(ConstraintEvaluator::jalr_lsb_constraint(row)); - - constraints.push(ConstraintEvaluator::load_addr_constraint(row)); - constraints.push(ConstraintEvaluator::store_addr_constraint(row)); - constraints.push(ConstraintEvaluator::load_value_constraint(row)); - constraints.push(ConstraintEvaluator::store_value_constraint(row)); - - constraints.push(ConstraintEvaluator::mul_constraint(row)); - constraints.push(ConstraintEvaluator::mul_hi_constraint(row)); - constraints.push(ConstraintEvaluator::div_constraint(row)); - constraints.push(ConstraintEvaluator::div_quotient_constraint(row)); - constraints.push(ConstraintEvaluator::rem_constraint(row)); - constraints.push(ConstraintEvaluator::div_remainder_range_constraint(row)); - constraints.push(ConstraintEvaluator::limb_range_constraint(row)); - - constraints -} + let mut constraints = Vec::new(); + + constraints.push(ConstraintEvaluator::x0_zero(row)); + constraints.push(ConstraintEvaluator::pc_increment(row)); + + let (add_c1, add_c2) = ConstraintEvaluator::add_constraint(row); + constraints.push(add_c1); + constraints.push(add_c2); + + let (sub_c1, sub_c2) = ConstraintEvaluator::sub_constraint(row); + constraints.push(sub_c1); + constraints.push(sub_c2); + + constraints.push(ConstraintEvaluator::and_constraint(row)); + constraints.push(ConstraintEvaluator::or_constraint(row)); + constraints.push(ConstraintEvaluator::xor_constraint(row)); + constraints.push(ConstraintEvaluator::sll_constraint(row)); + constraints.push(ConstraintEvaluator::srl_constraint(row)); + constraints.push(ConstraintEvaluator::sra_constraint(row)); + constraints.push(ConstraintEvaluator::slt_constraint(row)); + constraints.push(ConstraintEvaluator::sltu_constraint(row)); + constraints.push(ConstraintEvaluator::signed_lt_constraint(row)); + + constraints.push(ConstraintEvaluator::addi_constraint(row)); + constraints.push(ConstraintEvaluator::andi_constraint(row)); + constraints.push(ConstraintEvaluator::ori_constraint(row)); + constraints.push(ConstraintEvaluator::xori_constraint(row)); + constraints.push(ConstraintEvaluator::slti_constraint(row)); + constraints.push(ConstraintEvaluator::sltiu_constraint(row)); + constraints.push(ConstraintEvaluator::slli_constraint(row)); + constraints.push(ConstraintEvaluator::srli_constraint(row)); + constraints.push(ConstraintEvaluator::srai_constraint(row)); + + constraints.push(ConstraintEvaluator::lui_constraint(row)); + constraints.push(ConstraintEvaluator::auipc_constraint(row)); + + constraints.push(ConstraintEvaluator::beq_constraint(row)); + constraints.push(ConstraintEvaluator::beq_condition(row)); + constraints.push(ConstraintEvaluator::bne_constraint(row)); + constraints.push(ConstraintEvaluator::bne_condition(row)); + constraints.push(ConstraintEvaluator::blt_constraint(row)); + constraints.push(ConstraintEvaluator::blt_condition(row)); + constraints.push(ConstraintEvaluator::bge_constraint(row)); + constraints.push(ConstraintEvaluator::bge_condition(row)); + constraints.push(ConstraintEvaluator::bltu_constraint(row)); + constraints.push(ConstraintEvaluator::bltu_condition(row)); + constraints.push(ConstraintEvaluator::bgeu_constraint(row)); + constraints.push(ConstraintEvaluator::bgeu_condition(row)); + + let (jal_c1, jal_c2) = ConstraintEvaluator::jal_constraint(row); + constraints.push(jal_c1); + constraints.push(jal_c2); + + let (jalr_c1, jalr_c2) = ConstraintEvaluator::jalr_constraint(row); + constraints.push(jalr_c1); + constraints.push(jalr_c2); + constraints.push(ConstraintEvaluator::jalr_lsb_constraint(row)); + + constraints.push(ConstraintEvaluator::load_addr_constraint(row)); + constraints.push(ConstraintEvaluator::store_addr_constraint(row)); + constraints.push(ConstraintEvaluator::load_value_constraint(row)); + constraints.push(ConstraintEvaluator::store_value_constraint(row)); + + constraints.push(ConstraintEvaluator::mul_constraint(row)); + constraints.push(ConstraintEvaluator::mul_hi_constraint(row)); + constraints.push(ConstraintEvaluator::div_constraint(row)); + constraints.push(ConstraintEvaluator::div_quotient_constraint(row)); + constraints.push(ConstraintEvaluator::rem_constraint(row)); + constraints.push(ConstraintEvaluator::div_remainder_range_constraint(row)); + constraints.push(ConstraintEvaluator::limb_range_constraint(row)); + + constraints + } } #[cfg(test)] mod tests { use super::*; - + #[test] fn test_rv32im_air_creation() { let air = Rv32imAir::new(); assert!(air.num_constraints() > 30); - + // Check constraint names let names: Vec<_> = air.constraints().iter().map(|c| c.name).collect(); assert!(names.contains(&"add")); @@ -1565,74 +1737,74 @@ mod tests { assert!(names.contains(&"jal")); assert!(names.contains(&"mul_lo")); } - + #[test] fn test_add_constraint() { let mut row = CpuTraceRow::default(); - + // ADD: 5 + 3 = 8 row.is_add = M31::ONE; row.rs1_val_lo = M31::new(5); row.rs2_val_lo = M31::new(3); row.rd_val_lo = M31::new(8); row.carry = M31::ZERO; - + let (c1, c2) = ConstraintEvaluator::add_constraint(&row); assert_eq!(c1, M31::ZERO); assert_eq!(c2, M31::ZERO); } - + #[test] fn test_add_with_carry() { let mut row = CpuTraceRow::default(); - + // ADD causing carry: 0xFFFF + 1 = 0x10000 row.is_add = M31::ONE; row.rs1_val_lo = M31::new(0xFFFF); row.rs2_val_lo = M31::new(1); - row.rd_val_lo = M31::ZERO; // Low part is 0 - row.rd_val_hi = M31::ONE; // High part is 1 + row.rd_val_lo = M31::ZERO; // Low part is 0 + row.rd_val_hi = M31::ONE; // High part is 1 row.carry = M31::ONE; - + let (c1, c2) = ConstraintEvaluator::add_constraint(&row); assert_eq!(c1, M31::ZERO); assert_eq!(c2, M31::ZERO); } - + #[test] fn test_sub_constraint() { let mut row = CpuTraceRow::default(); - + // SUB: 10 - 3 = 7 row.is_sub = M31::ONE; row.rs1_val_lo = M31::new(10); row.rs2_val_lo = M31::new(3); row.rd_val_lo = M31::new(7); row.borrow = M31::ZERO; - + let (c1, c2) = ConstraintEvaluator::sub_constraint(&row); assert_eq!(c1, M31::ZERO); assert_eq!(c2, M31::ZERO); } - + #[test] fn test_lui_constraint() { let mut row = CpuTraceRow::default(); - + // LUI: rd = imm (upper 20 bits) row.is_lui = M31::ONE; row.imm = M31::new(0x12345000); row.rd_val_lo = M31::new(0x5000); row.rd_val_hi = M31::new(0x1234); - + let c = ConstraintEvaluator::lui_constraint(&row); assert_eq!(c, M31::ZERO); } - + #[test] fn test_beq_taken() { let mut row = CpuTraceRow::default(); - + // BEQ taken: pc = 100, imm = 20, next_pc should be 120 row.is_beq = M31::ONE; row.pc = M31::new(100); @@ -1640,18 +1812,18 @@ mod tests { row.next_pc = M31::new(120); row.branch_taken = M31::ONE; row.eq_result = M31::ONE; - + let c = ConstraintEvaluator::beq_constraint(&row); assert_eq!(c, M31::ZERO); - + let cond = ConstraintEvaluator::beq_condition(&row); assert_eq!(cond, M31::ZERO); } - + #[test] fn test_beq_not_taken() { let mut row = CpuTraceRow::default(); - + // BEQ not taken: pc = 100, next_pc = 104 row.is_beq = M31::ONE; row.pc = M31::new(100); @@ -1659,18 +1831,18 @@ mod tests { row.next_pc = M31::new(104); row.branch_taken = M31::ZERO; row.eq_result = M31::ZERO; - + let c = ConstraintEvaluator::beq_constraint(&row); assert_eq!(c, M31::ZERO); - + let cond = ConstraintEvaluator::beq_condition(&row); assert_eq!(cond, M31::ZERO); } - + #[test] fn test_jal_constraint() { let mut row = CpuTraceRow::default(); - + // JAL: pc = 100, imm = 50 // rd = 104, next_pc = 150 row.is_jal = M31::ONE; @@ -1679,49 +1851,49 @@ mod tests { row.next_pc = M31::new(150); row.rd_val_lo = M31::new(104); row.rd_val_hi = M31::ZERO; - + let (c1, c2) = ConstraintEvaluator::jal_constraint(&row); assert_eq!(c1, M31::ZERO); assert_eq!(c2, M31::ZERO); } - + #[test] fn test_pc_increment() { let mut row = CpuTraceRow::default(); - + // Sequential instruction: pc = 100, next_pc = 104 row.pc = M31::new(100); row.next_pc = M31::new(104); - + let c = ConstraintEvaluator::pc_increment(&row); assert_eq!(c, M31::ZERO); } - + #[test] fn test_evaluate_all() { let row = CpuTraceRow::default(); let constraints = ConstraintEvaluator::evaluate_all(&row); - + // Should return all constraints assert!(constraints.len() > 20); - + // Default row (all zeros) should satisfy most selector-guarded constraints for value in &constraints { let _ = value; } } - + #[test] fn test_load_addr() { let mut row = CpuTraceRow::default(); - + // LW: addr = rs1 + imm = 0x1000 + 0x10 = 0x1010 row.is_load = M31::ONE; row.rs1_val_lo = M31::new(0x1000); row.rs1_val_hi = M31::ZERO; row.imm = M31::new(0x10); row.mem_addr = M31::new(0x1010); - + let c = ConstraintEvaluator::load_addr_constraint(&row); assert_eq!(c, M31::ZERO); } @@ -1861,24 +2033,24 @@ mod tests { row.next_pc = M31::new(0x1004); let constraints = ConstraintEvaluator::evaluate_all(&row); - + // Should have 40+ constraints now (including new range constraints) assert!(constraints.len() >= 47); - + // Most constraints should be zero for correct execution let non_zero = constraints.iter().filter(|c| **c != M31::ZERO).count(); - + // Only a few constraints should be non-zero (for inactive instructions) assert!(non_zero < constraints.len()); } - + #[test] fn test_and_constraint_lookup() { let mut row = CpuTraceRow::default(); - + // AND: 0x12345678 & 0x0F0F0F0F = 0x02040608 row.is_and = M31::ONE; - + // rs1 = 0x12345678 row.rs1_val_lo = M31::new(0x5678); row.rs1_val_hi = M31::new(0x1234); @@ -1886,7 +2058,7 @@ mod tests { row.rs1_bytes[1] = M31::new(0x56); row.rs1_bytes[2] = M31::new(0x34); row.rs1_bytes[3] = M31::new(0x12); - + // rs2 = 0x0F0F0F0F row.rs2_val_lo = M31::new(0x0F0F); row.rs2_val_hi = M31::new(0x0F0F); @@ -1894,26 +2066,26 @@ mod tests { row.rs2_bytes[1] = M31::new(0x0F); row.rs2_bytes[2] = M31::new(0x0F); row.rs2_bytes[3] = M31::new(0x0F); - + // Result = 0x02040608 row.rd_val_lo = M31::new(0x0608); row.rd_val_hi = M31::new(0x0204); - row.and_result_bytes[0] = M31::new(0x08); // 0x78 & 0x0F = 0x08 - row.and_result_bytes[1] = M31::new(0x06); // 0x56 & 0x0F = 0x06 - row.and_result_bytes[2] = M31::new(0x04); // 0x34 & 0x0F = 0x04 - row.and_result_bytes[3] = M31::new(0x02); // 0x12 & 0x0F = 0x02 - + row.and_result_bytes[0] = M31::new(0x08); // 0x78 & 0x0F = 0x08 + row.and_result_bytes[1] = M31::new(0x06); // 0x56 & 0x0F = 0x06 + row.and_result_bytes[2] = M31::new(0x04); // 0x34 & 0x0F = 0x04 + row.and_result_bytes[3] = M31::new(0x02); // 0x12 & 0x0F = 0x02 + let c = ConstraintEvaluator::and_constraint_lookup(&row); assert_eq!(c, M31::ZERO, "Lookup AND constraint should be satisfied"); } - + #[test] fn test_or_constraint_lookup() { let mut row = CpuTraceRow::default(); - + // OR: 0x12000034 | 0x00560078 = 0x125600BC row.is_or = M31::ONE; - + // rs1 = 0x12000034 row.rs1_val_lo = M31::new(0x0034); row.rs1_val_hi = M31::new(0x1200); @@ -1921,7 +2093,7 @@ mod tests { row.rs1_bytes[1] = M31::new(0x00); row.rs1_bytes[2] = M31::new(0x00); row.rs1_bytes[3] = M31::new(0x12); - + // rs2 = 0x00560078 row.rs2_val_lo = M31::new(0x0078); row.rs2_val_hi = M31::new(0x0056); @@ -1929,27 +2101,27 @@ mod tests { row.rs2_bytes[1] = M31::new(0x00); row.rs2_bytes[2] = M31::new(0x56); row.rs2_bytes[3] = M31::new(0x00); - + // Result = 0x125600BC (0x34 | 0x78 = 0x7C, but let's use correct values) // Actually: 0x34 | 0x78 = 0x7C, 0x00 | 0x00 = 0x00, 0x00 | 0x56 = 0x56, 0x12 | 0x00 = 0x12 row.rd_val_lo = M31::new(0x007C); row.rd_val_hi = M31::new(0x1256); - row.or_result_bytes[0] = M31::new(0x7C); // 0x34 | 0x78 - row.or_result_bytes[1] = M31::new(0x00); // 0x00 | 0x00 - row.or_result_bytes[2] = M31::new(0x56); // 0x00 | 0x56 - row.or_result_bytes[3] = M31::new(0x12); // 0x12 | 0x00 - + row.or_result_bytes[0] = M31::new(0x7C); // 0x34 | 0x78 + row.or_result_bytes[1] = M31::new(0x00); // 0x00 | 0x00 + row.or_result_bytes[2] = M31::new(0x56); // 0x00 | 0x56 + row.or_result_bytes[3] = M31::new(0x12); // 0x12 | 0x00 + let c = ConstraintEvaluator::or_constraint_lookup(&row); assert_eq!(c, M31::ZERO, "Lookup OR constraint should be satisfied"); } - + #[test] fn test_xor_constraint_lookup() { let mut row = CpuTraceRow::default(); - + // XOR: 0xAAAAAAAA ^ 0x55555555 = 0xFFFFFFFF row.is_xor = M31::ONE; - + // rs1 = 0xAAAAAAAA row.rs1_val_lo = M31::new(0xAAAA); row.rs1_val_hi = M31::new(0xAAAA); @@ -1957,7 +2129,7 @@ mod tests { row.rs1_bytes[1] = M31::new(0xAA); row.rs1_bytes[2] = M31::new(0xAA); row.rs1_bytes[3] = M31::new(0xAA); - + // rs2 = 0x55555555 row.rs2_val_lo = M31::new(0x5555); row.rs2_val_hi = M31::new(0x5555); @@ -1965,15 +2137,15 @@ mod tests { row.rs2_bytes[1] = M31::new(0x55); row.rs2_bytes[2] = M31::new(0x55); row.rs2_bytes[3] = M31::new(0x55); - + // Result = 0xFFFFFFFF row.rd_val_lo = M31::new(0xFFFF); row.rd_val_hi = M31::new(0xFFFF); - row.xor_result_bytes[0] = M31::new(0xFF); // 0xAA ^ 0x55 = 0xFF + row.xor_result_bytes[0] = M31::new(0xFF); // 0xAA ^ 0x55 = 0xFF row.xor_result_bytes[1] = M31::new(0xFF); row.xor_result_bytes[2] = M31::new(0xFF); row.xor_result_bytes[3] = M31::new(0xFF); - + let c = ConstraintEvaluator::xor_constraint_lookup(&row); assert_eq!(c, M31::ZERO, "Lookup XOR constraint should be satisfied"); } diff --git a/crates/executor/benches/syscall_bench.rs b/crates/executor/benches/syscall_bench.rs index ab5fdc6..3e59495 100644 --- a/crates/executor/benches/syscall_bench.rs +++ b/crates/executor/benches/syscall_bench.rs @@ -2,7 +2,7 @@ //! //! Run with: cargo bench -p zp1-executor --bench syscall_bench -use criterion::{black_box, criterion_group, criterion_main, Criterion, BenchmarkId}; +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; use zp1_executor::Cpu; // Syscall numbers @@ -20,19 +20,20 @@ const BLAKE2B_SYSCALL: u32 = 0x1005; fn setup_cpu_with_program(syscall_num: u32) -> Cpu { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + // Create program: ecall, then exit let program: Vec = vec![ - 0x00000073, // ecall (crypto operation) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (crypto operation) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - - let program_bytes: Vec = program.iter() + + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + cpu.set_reg(17, syscall_num); cpu } @@ -43,29 +44,29 @@ fn setup_cpu_with_program(syscall_num: u32) -> Cpu { fn bench_keccak256_syscall(c: &mut Criterion) { let mut group = c.benchmark_group("Syscall-Keccak256"); - + for size in [32, 64, 128, 256, 512, 1024].iter() { let message = vec![0x42u8; *size]; - + group.bench_with_input( BenchmarkId::from_parameter(format!("{}B", size)), size, |b, _| { b.iter(|| { let mut cpu = setup_cpu_with_program(KECCAK256_SYSCALL); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + // Write message to memory for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + cpu.set_reg(10, input_ptr); cpu.set_reg(11, message.len() as u32); cpu.set_reg(12, output_ptr); - + // Execute syscall for _ in 0..5 { if cpu.pc == 4 { @@ -73,13 +74,13 @@ fn bench_keccak256_syscall(c: &mut Criterion) { } let _ = cpu.step(); } - + black_box(cpu) }) }, ); } - + group.finish(); } @@ -89,41 +90,41 @@ fn bench_keccak256_syscall(c: &mut Criterion) { fn bench_sha256_syscall(c: &mut Criterion) { let mut group = c.benchmark_group("Syscall-SHA256"); - + for size in [32, 64, 128, 256, 512, 1024].iter() { let message = vec![0x42u8; *size]; - + group.bench_with_input( BenchmarkId::from_parameter(format!("{}B", size)), size, |b, _| { b.iter(|| { let mut cpu = setup_cpu_with_program(SHA256_SYSCALL); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + cpu.set_reg(10, input_ptr); cpu.set_reg(11, message.len() as u32); cpu.set_reg(12, output_ptr); - + for _ in 0..5 { if cpu.pc == 4 { break; } let _ = cpu.step(); } - + black_box(cpu) }) }, ); } - + group.finish(); } @@ -133,35 +134,35 @@ fn bench_sha256_syscall(c: &mut Criterion) { fn bench_ripemd160_syscall(c: &mut Criterion) { let mut group = c.benchmark_group("Syscall-RIPEMD160"); - + let message = vec![0x42u8; 256]; - + group.bench_function("256B", |b| { b.iter(|| { let mut cpu = setup_cpu_with_program(RIPEMD160_SYSCALL); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + cpu.set_reg(10, input_ptr); cpu.set_reg(11, message.len() as u32); cpu.set_reg(12, output_ptr); - + for _ in 0..5 { if cpu.pc == 4 { break; } let _ = cpu.step(); } - + black_box(cpu) }) }); - + group.finish(); } @@ -171,41 +172,41 @@ fn bench_ripemd160_syscall(c: &mut Criterion) { fn bench_blake2b_syscall(c: &mut Criterion) { let mut group = c.benchmark_group("Syscall-Blake2b"); - + for size in [32, 64, 128, 256, 512, 1024].iter() { let message = vec![0x42u8; *size]; - + group.bench_with_input( BenchmarkId::from_parameter(format!("{}B", size)), size, |b, _| { b.iter(|| { let mut cpu = setup_cpu_with_program(BLAKE2B_SYSCALL); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + cpu.set_reg(10, input_ptr); cpu.set_reg(11, message.len() as u32); cpu.set_reg(12, output_ptr); - + for _ in 0..5 { if cpu.pc == 4 { break; } let _ = cpu.step(); } - + black_box(cpu) }) }, ); } - + group.finish(); } @@ -215,49 +216,49 @@ fn bench_blake2b_syscall(c: &mut Criterion) { fn bench_ecrecover_syscall(c: &mut Criterion) { let mut group = c.benchmark_group("Syscall-ECRECOVER"); - + let hash: [u8; 32] = [0x47; 32]; let v = 28u8; let r: [u8; 32] = [0xb9; 32]; let s: [u8; 32] = [0x3c; 32]; - + group.bench_function("signature_recovery", |b| { b.iter(|| { let mut cpu = setup_cpu_with_program(ECRECOVER_SYSCALL); - + let hash_ptr = 0x1000; let output_ptr = 0x2000; - + // Write inputs to memory for (i, &byte) in hash.iter().enumerate() { cpu.memory.write_byte(hash_ptr + i as u32, byte).unwrap(); } - + let r_ptr = hash_ptr + 32; for (i, &byte) in r.iter().enumerate() { cpu.memory.write_byte(r_ptr + i as u32, byte).unwrap(); } - + let s_ptr = r_ptr + 32; for (i, &byte) in s.iter().enumerate() { cpu.memory.write_byte(s_ptr + i as u32, byte).unwrap(); } - + cpu.set_reg(10, hash_ptr); cpu.set_reg(11, (v as u32) | ((r_ptr) << 8)); cpu.set_reg(12, (s_ptr) | ((output_ptr) << 16)); - + for _ in 0..5 { if cpu.pc == 4 { break; } let _ = cpu.step(); } - + black_box(cpu) }) }); - + group.finish(); } @@ -267,39 +268,39 @@ fn bench_ecrecover_syscall(c: &mut Criterion) { fn bench_modexp_syscall(c: &mut Criterion) { let mut group = c.benchmark_group("Syscall-MODEXP"); - + group.bench_function("small_exponent", |b| { b.iter(|| { let mut cpu = setup_cpu_with_program(MODEXP_SYSCALL); - + let base_ptr = 0x1000; let exp_ptr = 0x1020; let mod_ptr = 0x1040; let result_ptr = 0x2000; - + // Write base = 2 cpu.memory.write_byte(base_ptr, 2).unwrap(); // Write exp = 3 cpu.memory.write_byte(exp_ptr, 3).unwrap(); // Write mod = 5 cpu.memory.write_byte(mod_ptr, 5).unwrap(); - + cpu.set_reg(10, base_ptr); cpu.set_reg(11, exp_ptr); cpu.set_reg(12, mod_ptr); cpu.set_reg(13, result_ptr); - + for _ in 0..5 { if cpu.pc == 4 { break; } let _ = cpu.step(); } - + black_box(cpu) }) }); - + group.finish(); } @@ -309,50 +310,50 @@ fn bench_modexp_syscall(c: &mut Criterion) { fn bench_bitcoin_address_workflow(c: &mut Criterion) { let mut group = c.benchmark_group("Workflow-Bitcoin-Address"); - + let pubkey = [0x04u8; 65]; - + group.bench_function("sha256_then_ripemd160", |b| { b.iter(|| { // SHA-256 syscall let mut cpu1 = setup_cpu_with_program(SHA256_SYSCALL); let input_ptr = 0x1000; let sha_out = 0x2000; - + for (i, &byte) in pubkey.iter().enumerate() { cpu1.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + cpu1.set_reg(10, input_ptr); cpu1.set_reg(11, 65); cpu1.set_reg(12, sha_out); - + for _ in 0..5 { if cpu1.pc == 4 { break; } let _ = cpu1.step(); } - + // RIPEMD-160 syscall let mut cpu2 = setup_cpu_with_program(RIPEMD160_SYSCALL); let ripemd_out = 0x3000; - + cpu2.set_reg(10, sha_out); cpu2.set_reg(11, 32); cpu2.set_reg(12, ripemd_out); - + for _ in 0..5 { if cpu2.pc == 4 { break; } let _ = cpu2.step(); } - + black_box((cpu1, cpu2)) }) }); - + group.finish(); } diff --git a/crates/executor/src/cpu.rs b/crates/executor/src/cpu.rs index b714fd9..183ddb3 100644 --- a/crates/executor/src/cpu.rs +++ b/crates/executor/src/cpu.rs @@ -84,8 +84,10 @@ //! The trace captures every CPU state transition and is used by the prover //! to generate a zero-knowledge proof of correct execution. -use crate::decode::{opcode, DecodedInstr, branch_funct3, load_funct3, store_funct3, - op_imm_funct3, op_funct3, system_funct3, funct7}; +use crate::decode::{ + branch_funct3, funct7, load_funct3, op_funct3, op_imm_funct3, opcode, store_funct3, + system_funct3, DecodedInstr, +}; use crate::error::ExecutorError; use crate::memory::Memory; use crate::trace::{ExecutionTrace, InstrFlags, MemOp, TraceRow}; @@ -239,7 +241,12 @@ impl Cpu { branch_funct3::BGE => (rs1_val as i32) >= (rs2_val as i32), branch_funct3::BLTU => rs1_val < rs2_val, branch_funct3::BGEU => rs1_val >= rs2_val, - _ => return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }), + _ => { + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }) + } }; if taken { @@ -257,13 +264,21 @@ impl Cpu { // LB: Load Byte (sign-extended) let val = self.memory.read_u8(addr)?; rd_val = (val as i8) as i32 as u32; - mem_op = MemOp::LoadByte { addr, value: val, signed: true }; + mem_op = MemOp::LoadByte { + addr, + value: val, + signed: true, + }; } load_funct3::LH => { // LH: Load Halfword (sign-extended) let val = self.memory.read_u16(addr)?; rd_val = (val as i16) as i32 as u32; - mem_op = MemOp::LoadHalf { addr, value: val, signed: true }; + mem_op = MemOp::LoadHalf { + addr, + value: val, + signed: true, + }; } load_funct3::LW => { // LW: Load Word @@ -275,15 +290,28 @@ impl Cpu { // LBU: Load Byte Unsigned let val = self.memory.read_u8(addr)?; rd_val = val as u32; - mem_op = MemOp::LoadByte { addr, value: val, signed: false }; + mem_op = MemOp::LoadByte { + addr, + value: val, + signed: false, + }; } load_funct3::LHU => { // LHU: Load Halfword Unsigned let val = self.memory.read_u16(addr)?; rd_val = val as u32; - mem_op = MemOp::LoadHalf { addr, value: val, signed: false }; + mem_op = MemOp::LoadHalf { + addr, + value: val, + signed: false, + }; + } + _ => { + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }) } - _ => return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }), } } @@ -311,7 +339,12 @@ impl Cpu { self.memory.write_u32(addr, val)?; mem_op = MemOp::StoreWord { addr, value: val }; } - _ => return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }), + _ => { + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }) + } } } @@ -361,7 +394,12 @@ impl Cpu { rs1_val >> shamt } } - _ => return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }), + _ => { + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }) + } }; } @@ -451,7 +489,12 @@ impl Cpu { rs1_val % rs2_val }; } - _ => return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }), + _ => { + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }) + } } } else { // ========== Base RV32I OP ========== @@ -497,7 +540,12 @@ impl Cpu { // AND rs1_val & rs2_val } - _ => return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }), + _ => { + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }) + } }; } } @@ -512,7 +560,7 @@ impl Cpu { // ECALL - Environment Call flags.is_ecall = true; let syscall_id = self.get_reg(17); // a7 register - + // Handle specific supported syscalls match syscall_id { 0x1000 => { @@ -523,32 +571,40 @@ impl Cpu { let input_ptr = self.get_reg(10); let input_len = self.get_reg(11); let output_ptr = self.get_reg(12); - + // Validate pointers if !self.memory.is_valid_range(input_ptr, input_len) { - return Err(ExecutorError::OutOfBounds { addr: input_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: input_ptr, + }); } if !self.memory.is_valid_range(output_ptr, 32) { - return Err(ExecutorError::OutOfBounds { addr: output_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: output_ptr, + }); } - + // Extract input data - let input_data = self.memory.slice(input_ptr, input_len as usize) - .ok_or(ExecutorError::OutOfBounds { addr: input_ptr })?; - + let input_data = self + .memory + .slice(input_ptr, input_len as usize) + .ok_or(ExecutorError::OutOfBounds { + addr: input_ptr, + })?; + // Compute Keccak256 hash using delegation module let hash = zp1_delegation::keccak::keccak256(input_data); - + // Write output to memory self.memory.write_slice(output_ptr, &hash)?; - + // Record the delegation in trace mem_op = MemOp::Keccak256 { input_ptr, input_len, output_ptr, }; - + // Return success (a0 = 0) self.set_reg(10, 0); next_pc = self.pc.wrapping_add(4); @@ -559,19 +615,24 @@ impl Cpu { // a1 = output pointer (20 bytes: address) let input_ptr = self.get_reg(10); let output_ptr = self.get_reg(11); - + // Validate pointers if !self.memory.is_valid_range(input_ptr, 97) { - return Err(ExecutorError::OutOfBounds { addr: input_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: input_ptr, + }); } if !self.memory.is_valid_range(output_ptr, 20) { - return Err(ExecutorError::OutOfBounds { addr: output_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: output_ptr, + }); } - + // Extract input data - let input_data = self.memory.slice(input_ptr, 97) - .ok_or(ExecutorError::OutOfBounds { addr: input_ptr })?; - + let input_data = self.memory.slice(input_ptr, 97).ok_or( + ExecutorError::OutOfBounds { addr: input_ptr }, + )?; + let mut hash = [0u8; 32]; let mut r = [0u8; 32]; let mut s = [0u8; 32]; @@ -579,10 +640,11 @@ impl Cpu { let v = input_data[32]; r.copy_from_slice(&input_data[33..65]); s.copy_from_slice(&input_data[65..97]); - + // Perform ECRECOVER using delegation module - let address = zp1_delegation::ecrecover::ecrecover(&hash, v, &r, &s); - + let address = + zp1_delegation::ecrecover::ecrecover(&hash, v, &r, &s); + // Write output to memory match address { Some(addr) => { @@ -595,13 +657,13 @@ impl Cpu { self.set_reg(10, 1); // Failure } } - + // Record the delegation in trace mem_op = MemOp::Ecrecover { input_ptr, output_ptr, }; - + next_pc = self.pc.wrapping_add(4); } 0x1002 => { @@ -609,36 +671,44 @@ impl Cpu { // a0 = message pointer // a1 = message length // a2 = digest pointer (32 bytes) - + let message_ptr = self.get_reg(10); let message_len = self.get_reg(11); let digest_ptr = self.get_reg(12); - + // Validate pointers if !self.memory.is_valid_range(message_ptr, message_len) { - return Err(ExecutorError::OutOfBounds { addr: message_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: message_ptr, + }); } if !self.memory.is_valid_range(digest_ptr, 32) { - return Err(ExecutorError::OutOfBounds { addr: digest_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: digest_ptr, + }); } - + // Extract input data - let message = self.memory.slice(message_ptr, message_len as usize) - .ok_or(ExecutorError::OutOfBounds { addr: message_ptr })?; - + let message = self + .memory + .slice(message_ptr, message_len as usize) + .ok_or(ExecutorError::OutOfBounds { + addr: message_ptr, + })?; + // Compute SHA-256 hash using delegation module let digest = zp1_delegation::sha256::sha256(message); - + // Write digest to memory self.memory.write_slice(digest_ptr, &digest)?; - + // Record the delegation in trace mem_op = MemOp::Sha256 { message_ptr: message_ptr as usize, message_len: message_len as usize, digest_ptr: digest_ptr as usize, }; - + // Return success (a0 = 0) self.set_reg(10, 0); next_pc = self.pc.wrapping_add(4); @@ -648,36 +718,44 @@ impl Cpu { // a0 = message pointer // a1 = message length // a2 = digest pointer (20 bytes) - + let message_ptr = self.get_reg(10); let message_len = self.get_reg(11); let digest_ptr = self.get_reg(12); - + // Validate pointers if !self.memory.is_valid_range(message_ptr, message_len) { - return Err(ExecutorError::OutOfBounds { addr: message_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: message_ptr, + }); } if !self.memory.is_valid_range(digest_ptr, 20) { - return Err(ExecutorError::OutOfBounds { addr: digest_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: digest_ptr, + }); } - + // Extract input data - let message = self.memory.slice(message_ptr, message_len as usize) - .ok_or(ExecutorError::OutOfBounds { addr: message_ptr })?; - + let message = self + .memory + .slice(message_ptr, message_len as usize) + .ok_or(ExecutorError::OutOfBounds { + addr: message_ptr, + })?; + // Compute RIPEMD-160 hash using delegation module let digest = zp1_delegation::ripemd160::ripemd160(message); - + // Write digest to memory self.memory.write_slice(digest_ptr, &digest)?; - + // Record the delegation in trace mem_op = MemOp::Ripemd160 { message_ptr: message_ptr as usize, message_len: message_len as usize, digest_ptr: digest_ptr as usize, }; - + // Return success (a0 = 0) self.set_reg(10, 0); next_pc = self.pc.wrapping_add(4); @@ -688,57 +766,74 @@ impl Cpu { // a1 = exponent pointer (32 bytes) // a2 = modulus pointer (32 bytes) // a3 = result pointer (32 bytes) - + let base_ptr = self.get_reg(10); let exp_ptr = self.get_reg(11); let mod_ptr = self.get_reg(12); let result_ptr = self.get_reg(13); - + // Validate pointers if !self.memory.is_valid_range(base_ptr, 32) { - return Err(ExecutorError::OutOfBounds { addr: base_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: base_ptr, + }); } if !self.memory.is_valid_range(exp_ptr, 32) { - return Err(ExecutorError::OutOfBounds { addr: exp_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: exp_ptr, + }); } if !self.memory.is_valid_range(mod_ptr, 32) { - return Err(ExecutorError::OutOfBounds { addr: mod_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: mod_ptr, + }); } if !self.memory.is_valid_range(result_ptr, 32) { - return Err(ExecutorError::OutOfBounds { addr: result_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: result_ptr, + }); } - + // Extract input data - let base_bytes = self.memory.slice(base_ptr, 32) + let base_bytes = self + .memory + .slice(base_ptr, 32) .ok_or(ExecutorError::OutOfBounds { addr: base_ptr })?; - let exp_bytes = self.memory.slice(exp_ptr, 32) + let exp_bytes = self + .memory + .slice(exp_ptr, 32) .ok_or(ExecutorError::OutOfBounds { addr: exp_ptr })?; - let mod_bytes = self.memory.slice(mod_ptr, 32) + let mod_bytes = self + .memory + .slice(mod_ptr, 32) .ok_or(ExecutorError::OutOfBounds { addr: mod_ptr })?; - + // Convert to U256 let base = zp1_delegation::bigint::U256::from_le_bytes( - base_bytes.try_into().unwrap() + base_bytes.try_into().unwrap(), ); let exponent = zp1_delegation::bigint::U256::from_le_bytes( - exp_bytes.try_into().unwrap() + exp_bytes.try_into().unwrap(), ); let modulus = zp1_delegation::bigint::U256::from_le_bytes( - mod_bytes.try_into().unwrap() + mod_bytes.try_into().unwrap(), ); - + // Compute modular exponentiation using delegation - let delegation_call = zp1_delegation::bigint::delegate_u256_modexp( - &base, &exponent, &modulus - ); - + let delegation_call = + zp1_delegation::bigint::delegate_u256_modexp( + &base, &exponent, &modulus, + ); + // Convert M31 output limbs back to U256 - let result = zp1_delegation::bigint::U256::from_m31_limbs(&delegation_call.output); - + let result = zp1_delegation::bigint::U256::from_m31_limbs( + &delegation_call.output, + ); + // Write result to memory let result_bytes = result.to_le_bytes(); self.memory.write_slice(result_ptr, &result_bytes)?; - + // Record the delegation in trace mem_op = MemOp::Modexp { base_ptr: base_ptr as usize, @@ -746,7 +841,7 @@ impl Cpu { mod_ptr: mod_ptr as usize, result_ptr: result_ptr as usize, }; - + // Return success (a0 = 0) self.set_reg(10, 0); next_pc = self.pc.wrapping_add(4); @@ -756,47 +851,61 @@ impl Cpu { // a0 = message pointer // a1 = message length // a2 = digest pointer (64 bytes) - + let message_ptr = self.get_reg(10); let message_len = self.get_reg(11); let digest_ptr = self.get_reg(12); - + // Validate pointers if !self.memory.is_valid_range(message_ptr, message_len) { - return Err(ExecutorError::OutOfBounds { addr: message_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: message_ptr, + }); } if !self.memory.is_valid_range(digest_ptr, 64) { - return Err(ExecutorError::OutOfBounds { addr: digest_ptr }); + return Err(ExecutorError::OutOfBounds { + addr: digest_ptr, + }); } - + // Extract input data - let message = self.memory.slice(message_ptr, message_len as usize) - .ok_or(ExecutorError::OutOfBounds { addr: message_ptr })?; - + let message = self + .memory + .slice(message_ptr, message_len as usize) + .ok_or(ExecutorError::OutOfBounds { + addr: message_ptr, + })?; + // Compute Blake2b hash using delegation module let digest = zp1_delegation::blake2b::blake2b(message); - + // Write digest to memory self.memory.write_slice(digest_ptr, &digest)?; - + // Record the delegation in trace mem_op = MemOp::Blake2b { message_ptr: message_ptr as usize, message_len: message_len as usize, digest_ptr: digest_ptr as usize, }; - + // Return success (a0 = 0) self.set_reg(10, 0); next_pc = self.pc.wrapping_add(4); } 93 => { // Linux exit syscall - allow this for program termination - return Err(ExecutorError::Ecall { pc: self.pc, syscall_id }); + return Err(ExecutorError::Ecall { + pc: self.pc, + syscall_id, + }); } _ => { // Unsupported syscall - return Err(ExecutorError::Ecall { pc: self.pc, syscall_id }); + return Err(ExecutorError::Ecall { + pc: self.pc, + syscall_id, + }); } } } @@ -811,21 +920,37 @@ impl Cpu { } 0x302 => { // MRET - Machine Return (not supported in single-mode) - return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }); + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }); } _ => { // Unknown privileged instruction - return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }); + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }); } } } // CSR instructions are not supported in this minimal implementation - system_funct3::CSRRW | system_funct3::CSRRS | system_funct3::CSRRC | - system_funct3::CSRRWI | system_funct3::CSRRSI | system_funct3::CSRRCI => { - return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }); + system_funct3::CSRRW + | system_funct3::CSRRS + | system_funct3::CSRRC + | system_funct3::CSRRWI + | system_funct3::CSRRSI + | system_funct3::CSRRCI => { + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }); } _ => { - return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }); + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }); } } } @@ -836,18 +961,25 @@ impl Cpu { // In single-threaded deterministic execution, these are NOPs // No memory reordering or cache coherency needed // Just advance to next instruction (already set above) - + // Mark in flags for tracing purposes flags.is_alu = false; // Not an ALU op, just a NOP-like fence } _ => { - return Err(ExecutorError::InvalidInstruction { pc: self.pc, bits: instr_bits }); + return Err(ExecutorError::InvalidInstruction { + pc: self.pc, + bits: instr_bits, + }); } } // Write back to register (x0 writes are ignored, stores don't write, branches don't write) - if instr.rd != 0 && !instr.is_store() && !instr.is_branch() && instr.opcode != opcode::SYSTEM { + if instr.rd != 0 + && !instr.is_store() + && !instr.is_branch() + && instr.opcode != opcode::SYSTEM + { self.set_reg(instr.rd, rd_val); } @@ -902,7 +1034,10 @@ impl Cpu { return Ok(trace); } // For other syscalls, return the error - return Err(ExecutorError::Ecall { pc: self.pc, syscall_id }); + return Err(ExecutorError::Ecall { + pc: self.pc, + syscall_id, + }); } Err(e) => return Err(e), } @@ -1093,9 +1228,9 @@ mod tests { fn test_add() { let mut cpu = Cpu::with_memory_size(4096); let program: Vec = [ - assemble_addi(1, 0, 10), // x1 = 10 - assemble_addi(2, 0, 20), // x2 = 20 - assemble_add(3, 1, 2), // x3 = x1 + x2 = 30 + assemble_addi(1, 0, 10), // x1 = 10 + assemble_addi(2, 0, 20), // x2 = 20 + assemble_add(3, 1, 2), // x3 = x1 + x2 = 30 ] .iter() .flat_map(|i| i.to_le_bytes()) @@ -1161,15 +1296,15 @@ mod tests { cpu.load_program(0x100, &program).unwrap(); cpu.step().unwrap(); assert_eq!(cpu.get_reg(1), 0x104); // Return address - assert_eq!(cpu.pc, 0x108); // Jump target + assert_eq!(cpu.pc, 0x108); // Jump target } #[test] fn test_jalr() { let mut cpu = Cpu::with_memory_size(4096); let program: Vec = [ - assemble_addi(2, 0, 0x200), // x2 = 0x200 - assemble_jalr(1, 2, 4), // x1 = pc+4; pc = x2 + 4 = 0x204 + assemble_addi(2, 0, 0x200), // x2 = 0x200 + assemble_jalr(1, 2, 4), // x1 = pc+4; pc = x2 + 4 = 0x204 ] .iter() .flat_map(|i| i.to_le_bytes()) @@ -1178,8 +1313,8 @@ mod tests { cpu.load_program(0, &program).unwrap(); cpu.step().unwrap(); cpu.step().unwrap(); - assert_eq!(cpu.get_reg(1), 8); // Return address (0 + 4 + 4) - assert_eq!(cpu.pc, 0x204); // Jump target + assert_eq!(cpu.get_reg(1), 8); // Return address (0 + 4 + 4) + assert_eq!(cpu.pc, 0x204); // Jump target } #[test] @@ -1188,7 +1323,7 @@ mod tests { let program: Vec = [ assemble_addi(1, 0, 5), assemble_addi(2, 0, 5), - assemble_beq(1, 2, 8), // Branch if x1 == x2 (should branch) + assemble_beq(1, 2, 8), // Branch if x1 == x2 (should branch) ] .iter() .flat_map(|i| i.to_le_bytes()) @@ -1207,7 +1342,7 @@ mod tests { let program: Vec = [ assemble_addi(1, 0, 5), assemble_addi(2, 0, 10), - assemble_beq(1, 2, 8), // Branch if x1 == x2 (should NOT branch) + assemble_beq(1, 2, 8), // Branch if x1 == x2 (should NOT branch) ] .iter() .flat_map(|i| i.to_le_bytes()) @@ -1224,9 +1359,9 @@ mod tests { fn test_load_store() { let mut cpu = Cpu::with_memory_size(4096); let program: Vec = [ - assemble_addi(1, 0, 0x42), // x1 = 0x42 - assemble_sw(0, 1, 0x100), // mem[0x100] = x1 - assemble_lw(2, 0, 0x100), // x2 = mem[0x100] + assemble_addi(1, 0, 0x42), // x1 = 0x42 + assemble_sw(0, 1, 0x100), // mem[0x100] = x1 + assemble_lw(2, 0, 0x100), // x2 = mem[0x100] ] .iter() .flat_map(|i| i.to_le_bytes()) @@ -1337,7 +1472,7 @@ mod tests { fn test_sra() { let mut cpu = Cpu::with_memory_size(4096); let program: Vec = [ - assemble_addi(1, 0, -16), // x1 = 0xFFFFFFF0 + assemble_addi(1, 0, -16), // x1 = 0xFFFFFFF0 assemble_addi(2, 0, 2), assemble_r(opcode::OP, 3, 0b101, 1, 2, 0x20), // SRA x3, x1, x2 ] @@ -1393,8 +1528,8 @@ mod tests { fn test_blt() { let mut cpu = Cpu::with_memory_size(4096); let program: Vec = [ - assemble_addi(1, 0, -5), // x1 = -5 (signed) - assemble_addi(2, 0, 5), // x2 = 5 + assemble_addi(1, 0, -5), // x1 = -5 (signed) + assemble_addi(2, 0, 5), // x2 = 5 assemble_b(opcode::BRANCH, 0b100, 1, 2, 8), // BLT x1, x2, 8 ] .iter() @@ -1412,8 +1547,8 @@ mod tests { fn test_bltu() { let mut cpu = Cpu::with_memory_size(4096); let program: Vec = [ - assemble_addi(1, 0, -5), // x1 = 0xFFFFFFFB (unsigned) - assemble_addi(2, 0, 5), // x2 = 5 + assemble_addi(1, 0, -5), // x1 = 0xFFFFFFFB (unsigned) + assemble_addi(2, 0, 5), // x2 = 5 assemble_b(opcode::BRANCH, 0b110, 1, 2, 8), // BLTU x1, x2, 8 ] .iter() @@ -1432,7 +1567,7 @@ mod tests { let mut cpu = Cpu::with_memory_size(4096); // Store 0xFF at address 0x100 cpu.memory.write_u8(0x100, 0xFF).unwrap(); - + let program: Vec = [ assemble_i(opcode::LOAD, 1, 0b000, 0, 0x100), // LB x1, 0x100(x0) ] @@ -1449,7 +1584,7 @@ mod tests { fn test_load_byte_unsigned() { let mut cpu = Cpu::with_memory_size(4096); cpu.memory.write_u8(0x100, 0xFF).unwrap(); - + let program: Vec = [ assemble_i(opcode::LOAD, 1, 0b100, 0, 0x100), // LBU x1, 0x100(x0) ] @@ -1476,11 +1611,11 @@ mod tests { cpu.load_program(0, &program).unwrap(); cpu.enable_tracing(); - + cpu.step().unwrap(); cpu.step().unwrap(); cpu.step().unwrap(); - + let trace = cpu.take_trace().unwrap(); assert_eq!(trace.len(), 3); assert_eq!(trace.rows[0].rd_val, 10); diff --git a/crates/executor/src/decode.rs b/crates/executor/src/decode.rs index decfe9f..1f6fe46 100644 --- a/crates/executor/src/decode.rs +++ b/crates/executor/src/decode.rs @@ -34,12 +34,12 @@ pub struct DecodedInstr { /// RISC-V instruction formats. #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] pub enum InstrFormat { - R, // Register-register (ADD, SUB, MUL, etc.) - I, // Immediate (ADDI, LOAD, JALR) - S, // Store (SW, SH, SB) - B, // Branch (BEQ, BNE, BLT, etc.) - U, // Upper immediate (LUI, AUIPC) - J, // Jump (JAL) + R, // Register-register (ADD, SUB, MUL, etc.) + I, // Immediate (ADDI, LOAD, JALR) + S, // Store (SW, SH, SB) + B, // Branch (BEQ, BNE, BLT, etc.) + U, // Upper immediate (LUI, AUIPC) + J, // Jump (JAL) } /// Opcode constants for RV32IM. @@ -120,7 +120,7 @@ pub mod op_funct3 { /// funct3 values for SYSTEM instructions pub mod system_funct3 { - pub const PRIV: u8 = 0b000; // ECALL, EBREAK, WFI, MRET, etc. + pub const PRIV: u8 = 0b000; // ECALL, EBREAK, WFI, MRET, etc. pub const CSRRW: u8 = 0b001; pub const CSRRS: u8 = 0b010; pub const CSRRC: u8 = 0b011; @@ -214,8 +214,12 @@ impl DecodedInstr { /// Check if this is a NOP (ADDI x0, x0, 0). #[inline] pub fn is_nop(&self) -> bool { - self.bits == 0x00000013 || - (self.opcode == opcode::OP_IMM && self.funct3 == 0 && self.rd == 0 && self.rs1 == 0 && self.imm == 0) + self.bits == 0x00000013 + || (self.opcode == opcode::OP_IMM + && self.funct3 == 0 + && self.rd == 0 + && self.rs1 == 0 + && self.imm == 0) } /// Check if this is an M-extension instruction (multiply/divide). @@ -303,7 +307,13 @@ impl DecodedInstr { _ => "STORE?", }, opcode::OP_IMM => match self.funct3 { - op_imm_funct3::ADDI => if self.is_nop() { "NOP" } else { "ADDI" }, + op_imm_funct3::ADDI => { + if self.is_nop() { + "NOP" + } else { + "ADDI" + } + } op_imm_funct3::SLTI => "SLTI", op_imm_funct3::SLTIU => "SLTIU", op_imm_funct3::XORI => "XORI", @@ -311,8 +321,12 @@ impl DecodedInstr { op_imm_funct3::ANDI => "ANDI", op_imm_funct3::SLLI => "SLLI", op_imm_funct3::SRLI_SRAI => { - if self.funct7 & 0x20 != 0 { "SRAI" } else { "SRLI" } - }, + if self.funct7 & 0x20 != 0 { + "SRAI" + } else { + "SRLI" + } + } _ => "OP_IMM?", }, opcode::OP => { @@ -343,16 +357,14 @@ impl DecodedInstr { _ => "OP?", } } - }, + } opcode::SYSTEM => match self.funct3 { - system_funct3::PRIV => { - match self.imm as u32 & 0xFFF { - 0x000 => "ECALL", - 0x001 => "EBREAK", - 0x105 => "WFI", - 0x302 => "MRET", - _ => "PRIV?", - } + system_funct3::PRIV => match self.imm as u32 & 0xFFF { + 0x000 => "ECALL", + 0x001 => "EBREAK", + 0x105 => "WFI", + 0x302 => "MRET", + _ => "PRIV?", }, system_funct3::CSRRW => "CSRRW", system_funct3::CSRRS => "CSRRS", @@ -363,8 +375,12 @@ impl DecodedInstr { _ => "SYSTEM?", }, opcode::MISC_MEM => { - if self.funct3 == 0 { "FENCE" } else { "FENCE.I" } - }, + if self.funct3 == 0 { + "FENCE" + } else { + "FENCE.I" + } + } _ => "???", } } diff --git a/crates/executor/src/elf.rs b/crates/executor/src/elf.rs index 52b1bcc..4306316 100644 --- a/crates/executor/src/elf.rs +++ b/crates/executor/src/elf.rs @@ -283,38 +283,43 @@ impl ElfLoader { pub fn parse(data: &[u8]) -> Result { // Validate minimum size for ELF header if data.len() < ELF32_HEADER_SIZE { - return Err(ExecutorError::InvalidElf( - format!("File too small: {} bytes (need at least {})", data.len(), ELF32_HEADER_SIZE) - )); + return Err(ExecutorError::InvalidElf(format!( + "File too small: {} bytes (need at least {})", + data.len(), + ELF32_HEADER_SIZE + ))); } // Validate ELF magic number if data[0..4] != ELF_MAGIC { - return Err(ExecutorError::InvalidElf( - format!("Invalid magic: {:02x} {:02x} {:02x} {:02x}", - data[0], data[1], data[2], data[3]) - )); + return Err(ExecutorError::InvalidElf(format!( + "Invalid magic: {:02x} {:02x} {:02x} {:02x}", + data[0], data[1], data[2], data[3] + ))); } // Validate 32-bit class if data[4] != ELFCLASS32 { - return Err(ExecutorError::InvalidElf( - format!("Not 32-bit ELF (class: {})", data[4]) - )); + return Err(ExecutorError::InvalidElf(format!( + "Not 32-bit ELF (class: {})", + data[4] + ))); } // Validate little-endian encoding if data[5] != ELFDATA2LSB { - return Err(ExecutorError::InvalidElf( - format!("Not little-endian (encoding: {})", data[5]) - )); + return Err(ExecutorError::InvalidElf(format!( + "Not little-endian (encoding: {})", + data[5] + ))); } // Validate ELF version in e_ident if data[6] != EV_CURRENT { - return Err(ExecutorError::InvalidElf( - format!("Unsupported ELF version in ident: {}", data[6]) - )); + return Err(ExecutorError::InvalidElf(format!( + "Unsupported ELF version in ident: {}", + data[6] + ))); } // Parse header fields @@ -324,23 +329,26 @@ impl ElfLoader { // Validate ELF type (executable or shared object) if e_type != ET_EXEC && e_type != ET_DYN { - return Err(ExecutorError::InvalidElf( - format!("Not an executable (type: {})", e_type) - )); + return Err(ExecutorError::InvalidElf(format!( + "Not an executable (type: {})", + e_type + ))); } // Validate machine type if e_machine != EM_RISCV { - return Err(ExecutorError::InvalidElf( - format!("Not RISC-V (machine: {})", e_machine) - )); + return Err(ExecutorError::InvalidElf(format!( + "Not RISC-V (machine: {})", + e_machine + ))); } // Validate version if e_version != 1 { - return Err(ExecutorError::InvalidElf( - format!("Unsupported ELF version: {}", e_version) - )); + return Err(ExecutorError::InvalidElf(format!( + "Unsupported ELF version: {}", + e_version + ))); } // Parse full header @@ -362,9 +370,10 @@ impl ElfLoader { // Validate header size if header.ehsize as usize != ELF32_HEADER_SIZE { - return Err(ExecutorError::InvalidElf( - format!("Invalid ELF header size: {}", header.ehsize) - )); + return Err(ExecutorError::InvalidElf(format!( + "Invalid ELF header size: {}", + header.ehsize + ))); } // Parse program headers @@ -392,8 +401,8 @@ impl ElfLoader { /// Parse program headers from ELF data. fn parse_program_headers( - data: &[u8], - header: &Elf32Header + data: &[u8], + header: &Elf32Header, ) -> Result, ExecutorError> { let mut headers = Vec::with_capacity(header.phnum as usize); let phoff = header.phoff as usize; @@ -401,29 +410,71 @@ impl ElfLoader { // Validate program header entry size if phentsize < ELF32_PHDR_SIZE { - return Err(ExecutorError::InvalidElf( - format!("Program header size too small: {}", phentsize) - )); + return Err(ExecutorError::InvalidElf(format!( + "Program header size too small: {}", + phentsize + ))); } for i in 0..header.phnum as usize { let offset = phoff + i * phentsize; - + if offset + ELF32_PHDR_SIZE > data.len() { - return Err(ExecutorError::InvalidElf( - format!("Program header {} out of bounds (offset {})", i, offset) - )); + return Err(ExecutorError::InvalidElf(format!( + "Program header {} out of bounds (offset {})", + i, offset + ))); } let ph = Elf32ProgramHeader { - p_type: u32::from_le_bytes([data[offset], data[offset + 1], data[offset + 2], data[offset + 3]]), - p_offset: u32::from_le_bytes([data[offset + 4], data[offset + 5], data[offset + 6], data[offset + 7]]), - p_vaddr: u32::from_le_bytes([data[offset + 8], data[offset + 9], data[offset + 10], data[offset + 11]]), - p_paddr: u32::from_le_bytes([data[offset + 12], data[offset + 13], data[offset + 14], data[offset + 15]]), - p_filesz: u32::from_le_bytes([data[offset + 16], data[offset + 17], data[offset + 18], data[offset + 19]]), - p_memsz: u32::from_le_bytes([data[offset + 20], data[offset + 21], data[offset + 22], data[offset + 23]]), - p_flags: u32::from_le_bytes([data[offset + 24], data[offset + 25], data[offset + 26], data[offset + 27]]), - p_align: u32::from_le_bytes([data[offset + 28], data[offset + 29], data[offset + 30], data[offset + 31]]), + p_type: u32::from_le_bytes([ + data[offset], + data[offset + 1], + data[offset + 2], + data[offset + 3], + ]), + p_offset: u32::from_le_bytes([ + data[offset + 4], + data[offset + 5], + data[offset + 6], + data[offset + 7], + ]), + p_vaddr: u32::from_le_bytes([ + data[offset + 8], + data[offset + 9], + data[offset + 10], + data[offset + 11], + ]), + p_paddr: u32::from_le_bytes([ + data[offset + 12], + data[offset + 13], + data[offset + 14], + data[offset + 15], + ]), + p_filesz: u32::from_le_bytes([ + data[offset + 16], + data[offset + 17], + data[offset + 18], + data[offset + 19], + ]), + p_memsz: u32::from_le_bytes([ + data[offset + 20], + data[offset + 21], + data[offset + 22], + data[offset + 23], + ]), + p_flags: u32::from_le_bytes([ + data[offset + 24], + data[offset + 25], + data[offset + 26], + data[offset + 27], + ]), + p_align: u32::from_le_bytes([ + data[offset + 28], + data[offset + 29], + data[offset + 30], + data[offset + 31], + ]), }; headers.push(ph); @@ -435,7 +486,7 @@ impl ElfLoader { /// Parse section headers from ELF data. fn parse_section_headers( data: &[u8], - header: &Elf32Header + header: &Elf32Header, ) -> Result, ExecutorError> { // Section headers are optional if header.shoff == 0 || header.shnum == 0 { @@ -448,9 +499,10 @@ impl ElfLoader { // Validate section header entry size if shentsize < ELF32_SHDR_SIZE { - return Err(ExecutorError::InvalidElf( - format!("Section header size too small: {}", shentsize) - )); + return Err(ExecutorError::InvalidElf(format!( + "Section header size too small: {}", + shentsize + ))); } for i in 0..header.shnum as usize { @@ -462,16 +514,66 @@ impl ElfLoader { } let sh = Elf32SectionHeader { - sh_name: u32::from_le_bytes([data[offset], data[offset + 1], data[offset + 2], data[offset + 3]]), - sh_type: u32::from_le_bytes([data[offset + 4], data[offset + 5], data[offset + 6], data[offset + 7]]), - sh_flags: u32::from_le_bytes([data[offset + 8], data[offset + 9], data[offset + 10], data[offset + 11]]), - sh_addr: u32::from_le_bytes([data[offset + 12], data[offset + 13], data[offset + 14], data[offset + 15]]), - sh_offset: u32::from_le_bytes([data[offset + 16], data[offset + 17], data[offset + 18], data[offset + 19]]), - sh_size: u32::from_le_bytes([data[offset + 20], data[offset + 21], data[offset + 22], data[offset + 23]]), - sh_link: u32::from_le_bytes([data[offset + 24], data[offset + 25], data[offset + 26], data[offset + 27]]), - sh_info: u32::from_le_bytes([data[offset + 28], data[offset + 29], data[offset + 30], data[offset + 31]]), - sh_addralign: u32::from_le_bytes([data[offset + 32], data[offset + 33], data[offset + 34], data[offset + 35]]), - sh_entsize: u32::from_le_bytes([data[offset + 36], data[offset + 37], data[offset + 38], data[offset + 39]]), + sh_name: u32::from_le_bytes([ + data[offset], + data[offset + 1], + data[offset + 2], + data[offset + 3], + ]), + sh_type: u32::from_le_bytes([ + data[offset + 4], + data[offset + 5], + data[offset + 6], + data[offset + 7], + ]), + sh_flags: u32::from_le_bytes([ + data[offset + 8], + data[offset + 9], + data[offset + 10], + data[offset + 11], + ]), + sh_addr: u32::from_le_bytes([ + data[offset + 12], + data[offset + 13], + data[offset + 14], + data[offset + 15], + ]), + sh_offset: u32::from_le_bytes([ + data[offset + 16], + data[offset + 17], + data[offset + 18], + data[offset + 19], + ]), + sh_size: u32::from_le_bytes([ + data[offset + 20], + data[offset + 21], + data[offset + 22], + data[offset + 23], + ]), + sh_link: u32::from_le_bytes([ + data[offset + 24], + data[offset + 25], + data[offset + 26], + data[offset + 27], + ]), + sh_info: u32::from_le_bytes([ + data[offset + 28], + data[offset + 29], + data[offset + 30], + data[offset + 31], + ]), + sh_addralign: u32::from_le_bytes([ + data[offset + 32], + data[offset + 33], + data[offset + 34], + data[offset + 35], + ]), + sh_entsize: u32::from_le_bytes([ + data[offset + 36], + data[offset + 37], + data[offset + 38], + data[offset + 39], + ]), }; headers.push(sh); @@ -484,7 +586,7 @@ impl ElfLoader { fn load_section_strtab( data: &[u8], header: &Elf32Header, - sections: &[Elf32SectionHeader] + sections: &[Elf32SectionHeader], ) -> Option> { let idx = header.shstrndx as usize; if idx >= sections.len() { @@ -498,7 +600,7 @@ impl ElfLoader { let start = strtab.sh_offset as usize; let size = strtab.sh_size as usize; - + if start + size <= data.len() { Some(data[start..start + size].to_vec()) } else { @@ -509,10 +611,11 @@ impl ElfLoader { /// Parse symbol table from section headers. fn parse_symbol_table( data: &[u8], - sections: &[Elf32SectionHeader] + sections: &[Elf32SectionHeader], ) -> (Vec, Option>) { // Find symbol table section (.symtab or .dynsym) - let symtab = sections.iter() + let symtab = sections + .iter() .find(|s| s.sh_type == SHT_SYMTAB || s.sh_type == SHT_DYNSYM); let symtab = match symtab { @@ -539,10 +642,10 @@ impl ElfLoader { let mut symbols = Vec::new(); let start = symtab.sh_offset as usize; let size = symtab.sh_size as usize; - let entsize = if symtab.sh_entsize > 0 { - symtab.sh_entsize as usize - } else { - ELF32_SYM_SIZE + let entsize = if symtab.sh_entsize > 0 { + symtab.sh_entsize as usize + } else { + ELF32_SYM_SIZE }; if entsize < ELF32_SYM_SIZE || start + size > data.len() { @@ -557,9 +660,24 @@ impl ElfLoader { } let sym = Elf32Symbol { - st_name: u32::from_le_bytes([data[offset], data[offset + 1], data[offset + 2], data[offset + 3]]), - st_value: u32::from_le_bytes([data[offset + 4], data[offset + 5], data[offset + 6], data[offset + 7]]), - st_size: u32::from_le_bytes([data[offset + 8], data[offset + 9], data[offset + 10], data[offset + 11]]), + st_name: u32::from_le_bytes([ + data[offset], + data[offset + 1], + data[offset + 2], + data[offset + 3], + ]), + st_value: u32::from_le_bytes([ + data[offset + 4], + data[offset + 5], + data[offset + 6], + data[offset + 7], + ]), + st_size: u32::from_le_bytes([ + data[offset + 8], + data[offset + 9], + data[offset + 10], + data[offset + 11], + ]), st_info: data[offset + 12], st_other: data[offset + 13], st_shndx: u16::from_le_bytes([data[offset + 14], data[offset + 15]]), @@ -578,7 +696,9 @@ impl ElfLoader { /// Get all loadable segments (PT_LOAD). pub fn loadable_segments(&self) -> impl Iterator { - self.program_headers.iter().filter(|ph| ph.p_type == PT_LOAD) + self.program_headers + .iter() + .filter(|ph| ph.p_type == PT_LOAD) } /// Load the ELF into memory. @@ -601,16 +721,18 @@ impl ElfLoader { // Validate file bounds if file_size > 0 && file_offset.saturating_add(file_size) > self.data.len() { - return Err(ExecutorError::InvalidElf( - format!("Segment at 0x{:08x} data out of bounds", mem_addr) - )); + return Err(ExecutorError::InvalidElf(format!( + "Segment at 0x{:08x} data out of bounds", + mem_addr + ))); } // Validate memory size is at least file size if mem_size < file_size { - return Err(ExecutorError::InvalidElf( - format!("Segment at 0x{:08x} has memsz < filesz", mem_addr) - )); + return Err(ExecutorError::InvalidElf(format!( + "Segment at 0x{:08x} has memsz < filesz", + mem_addr + ))); } // Load file data into memory @@ -682,13 +804,14 @@ impl ElfLoader { if start >= strtab.len() { return None; } - + // Find null terminator - let end = strtab[start..].iter() + let end = strtab[start..] + .iter() .position(|&b| b == 0) .map(|pos| start + pos) .unwrap_or(strtab.len()); - + std::str::from_utf8(&strtab[start..end]).ok() } @@ -700,7 +823,8 @@ impl ElfLoader { return None; } - let end = strtab[start..].iter() + let end = strtab[start..] + .iter() .position(|&b| b == 0) .map(|pos| start + pos) .unwrap_or(strtab.len()); @@ -710,7 +834,8 @@ impl ElfLoader { /// Find a symbol by name. pub fn find_symbol(&self, name: &str) -> Option<&Elf32Symbol> { - self.symbols.iter() + self.symbols + .iter() .find(|s| self.symbol_name(s.st_name) == Some(name)) } @@ -815,16 +940,16 @@ pub mod symbol_type { pub fn build_test_elf(code: &[u8], entry: u32, load_addr: u32) -> Vec { // Align code segment to 4 bytes let code_padded_len = (code.len() + 3) & !3; - + let mut elf = Vec::with_capacity(ELF32_HEADER_SIZE + ELF32_PHDR_SIZE + code_padded_len); - + // ELF header (52 bytes) - elf.extend_from_slice(&ELF_MAGIC); // e_ident[0..4]: Magic - elf.push(ELFCLASS32); // e_ident[4]: Class (32-bit) - elf.push(ELFDATA2LSB); // e_ident[5]: Data (little-endian) - elf.push(EV_CURRENT); // e_ident[6]: Version - elf.push(0); // e_ident[7]: OS/ABI (SYSV) - elf.extend_from_slice(&[0u8; 8]); // e_ident[8..16]: Padding + elf.extend_from_slice(&ELF_MAGIC); // e_ident[0..4]: Magic + elf.push(ELFCLASS32); // e_ident[4]: Class (32-bit) + elf.push(ELFDATA2LSB); // e_ident[5]: Data (little-endian) + elf.push(EV_CURRENT); // e_ident[6]: Version + elf.push(0); // e_ident[7]: OS/ABI (SYSV) + elf.extend_from_slice(&[0u8; 8]); // e_ident[8..16]: Padding elf.extend_from_slice(&ET_EXEC.to_le_bytes()); // e_type: Executable elf.extend_from_slice(&EM_RISCV.to_le_bytes()); // e_machine: RISC-V elf.extend_from_slice(&1u32.to_le_bytes()); // e_version @@ -838,7 +963,7 @@ pub fn build_test_elf(code: &[u8], entry: u32, load_addr: u32) -> Vec { elf.extend_from_slice(&(ELF32_SHDR_SIZE as u16).to_le_bytes()); // e_shentsize elf.extend_from_slice(&0u16.to_le_bytes()); // e_shnum (no sections) elf.extend_from_slice(&0u16.to_le_bytes()); // e_shstrndx - + // Program header (32 bytes) let code_offset = ELF32_HEADER_SIZE + ELF32_PHDR_SIZE; elf.extend_from_slice(&PT_LOAD.to_le_bytes()); // p_type @@ -849,24 +974,24 @@ pub fn build_test_elf(code: &[u8], entry: u32, load_addr: u32) -> Vec { elf.extend_from_slice(&(code.len() as u32).to_le_bytes()); // p_memsz elf.extend_from_slice(&(segment_flags::PF_R | segment_flags::PF_X).to_le_bytes()); // p_flags elf.extend_from_slice(&4u32.to_le_bytes()); // p_align - + // Code segment elf.extend_from_slice(code); - + // Pad to 4-byte alignment while elf.len() % 4 != 0 { elf.push(0); } - + elf } /// Build an ELF with multiple segments (code + data + BSS) for testing. pub fn build_test_elf_with_data( - code: &[u8], + code: &[u8], data: &[u8], bss_size: u32, - entry: u32, + entry: u32, code_addr: u32, data_addr: u32, ) -> Vec { @@ -874,9 +999,9 @@ pub fn build_test_elf_with_data( let phdrs_size = num_phdrs * ELF32_PHDR_SIZE; let code_offset = ELF32_HEADER_SIZE + phdrs_size; let data_offset = code_offset + ((code.len() + 3) & !3); // Aligned - + let mut elf = Vec::new(); - + // ELF header elf.extend_from_slice(&ELF_MAGIC); elf.push(ELFCLASS32); @@ -897,7 +1022,7 @@ pub fn build_test_elf_with_data( elf.extend_from_slice(&(ELF32_SHDR_SIZE as u16).to_le_bytes()); elf.extend_from_slice(&0u16.to_le_bytes()); elf.extend_from_slice(&0u16.to_le_bytes()); - + // Program header 1: Code (PT_LOAD, R+X) elf.extend_from_slice(&PT_LOAD.to_le_bytes()); elf.extend_from_slice(&(code_offset as u32).to_le_bytes()); @@ -907,7 +1032,7 @@ pub fn build_test_elf_with_data( elf.extend_from_slice(&(code.len() as u32).to_le_bytes()); elf.extend_from_slice(&(segment_flags::PF_R | segment_flags::PF_X).to_le_bytes()); elf.extend_from_slice(&4u32.to_le_bytes()); - + // Program header 2: Data + BSS (PT_LOAD, R+W) let memsz = data.len() as u32 + bss_size; elf.extend_from_slice(&PT_LOAD.to_le_bytes()); @@ -918,16 +1043,16 @@ pub fn build_test_elf_with_data( elf.extend_from_slice(&memsz.to_le_bytes()); // memsz (includes BSS) elf.extend_from_slice(&(segment_flags::PF_R | segment_flags::PF_W).to_le_bytes()); elf.extend_from_slice(&4u32.to_le_bytes()); - + // Code segment elf.extend_from_slice(code); while elf.len() < data_offset { elf.push(0); } - + // Data segment elf.extend_from_slice(data); - + elf } @@ -946,10 +1071,10 @@ mod tests { 0x93, 0x00, 0xa0, 0x02, // addi x1, x0, 42 0x73, 0x00, 0x00, 0x00, // ecall ]; - + let elf_data = build_test_elf(&code, 0x1000, 0x1000); let loader = ElfLoader::parse(&elf_data).expect("Failed to parse ELF"); - + assert_eq!(loader.entry_point(), 0x1000); assert_eq!(loader.loadable_segments().count(), 1); assert_eq!(loader.header().e_type, ET_EXEC); @@ -962,15 +1087,15 @@ mod tests { 0x93, 0x00, 0xa0, 0x02, // addi x1, x0, 42 0x73, 0x00, 0x00, 0x00, // ecall ]; - + let elf_data = build_test_elf(&code, 0x1000, 0x1000); let loader = ElfLoader::parse(&elf_data).unwrap(); - + let mut memory = Memory::with_default_size(); let entry = loader.load_into_memory(&mut memory).unwrap(); - + assert_eq!(entry, 0x1000); - + // Verify code was loaded let instr = memory.read_u32(0x1000).unwrap(); assert_eq!(instr, 0x02a00093); // addi x1, x0, 42 @@ -987,8 +1112,11 @@ mod tests { let result = ElfLoader::parse(&bad_data); assert!(result.is_err()); let err_msg = result.unwrap_err().to_string(); - assert!(err_msg.contains("magic") || err_msg.contains("Magic"), - "Expected 'magic' in error: {}", err_msg); + assert!( + err_msg.contains("magic") || err_msg.contains("Magic"), + "Expected 'magic' in error: {}", + err_msg + ); } #[test] @@ -997,8 +1125,11 @@ mod tests { let result = ElfLoader::parse(&bad_data); assert!(result.is_err()); let err_msg = result.unwrap_err().to_string(); - assert!(err_msg.contains("small") || err_msg.contains("bytes"), - "Expected size-related error: {}", err_msg); + assert!( + err_msg.contains("small") || err_msg.contains("bytes"), + "Expected size-related error: {}", + err_msg + ); } #[test] @@ -1006,7 +1137,7 @@ mod tests { let code = vec![0x00; 100]; let elf_data = build_test_elf(&code, 0x2000, 0x2000); let loader = ElfLoader::parse(&elf_data).unwrap(); - + let (low, high) = loader.memory_bounds(); assert_eq!(low, 0x2000); assert_eq!(high, 0x2000 + 100); @@ -1020,33 +1151,35 @@ mod tests { ]; let data = vec![0x11, 0x22, 0x33, 0x44]; let bss_size = 16; - - let elf_data = build_test_elf_with_data( - &code, &data, bss_size, - 0x1000, 0x1000, 0x2000 - ); - + + let elf_data = build_test_elf_with_data(&code, &data, bss_size, 0x1000, 0x1000, 0x2000); + let loader = ElfLoader::parse(&elf_data).unwrap(); - + assert_eq!(loader.entry_point(), 0x1000); assert_eq!(loader.loadable_segments().count(), 2); - + let mut memory = Memory::with_default_size(); loader.load_into_memory(&mut memory).unwrap(); - + // Verify code let instr = memory.read_u32(0x1000).unwrap(); assert_eq!(instr, 0x02a00093); - + // Verify data assert_eq!(memory.read_u8(0x2000).unwrap(), 0x11); assert_eq!(memory.read_u8(0x2001).unwrap(), 0x22); assert_eq!(memory.read_u8(0x2002).unwrap(), 0x33); assert_eq!(memory.read_u8(0x2003).unwrap(), 0x44); - + // Verify BSS is zeroed for i in 0..bss_size { - assert_eq!(memory.read_u8(0x2004 + i).unwrap(), 0, "BSS byte {} not zero", i); + assert_eq!( + memory.read_u8(0x2004 + i).unwrap(), + 0, + "BSS byte {} not zero", + i + ); } } @@ -1064,7 +1197,7 @@ mod tests { elf[40..42].copy_from_slice(&52u16.to_le_bytes()); // ehsize elf[42..44].copy_from_slice(&32u16.to_le_bytes()); // phentsize elf[46..48].copy_from_slice(&40u16.to_le_bytes()); // shentsize - + let loader = ElfLoader::parse(&elf).unwrap(); let (low, high) = loader.memory_bounds(); assert_eq!(low, 0); @@ -1075,11 +1208,11 @@ mod tests { #[test] fn test_segment_flags() { use segment_flags::*; - + let code = vec![0x00; 4]; let elf_data = build_test_elf(&code, 0x1000, 0x1000); let loader = ElfLoader::parse(&elf_data).unwrap(); - + let seg = loader.loadable_segments().next().unwrap(); assert!(seg.p_flags & PF_R != 0, "Should be readable"); assert!(seg.p_flags & PF_X != 0, "Should be executable"); @@ -1099,7 +1232,7 @@ mod tests { let code = vec![0x00; 4]; let elf_data = build_test_elf(&code, 0x1000, 0x1000); let loader = ElfLoader::parse(&elf_data).unwrap(); - + let (rvc, flags) = loader.riscv_flags(); assert!(!rvc, "Test ELF has no RVC flag"); assert_eq!(flags, 0); @@ -1111,17 +1244,17 @@ mod tests { 0x00, 0x00, 0x00, 0x00, // nop (padding) 0x93, 0x00, 0xa0, 0x02, // addi x1, x0, 42 (entry point) ]; - + // Entry at offset 4 from load address let elf_data = build_test_elf(&code, 0x1004, 0x1000); let loader = ElfLoader::parse(&elf_data).unwrap(); - + assert_eq!(loader.entry_point(), 0x1004); - + let mut memory = Memory::with_default_size(); let entry = loader.load_into_memory(&mut memory).unwrap(); assert_eq!(entry, 0x1004); - + // Verify the actual entry point instruction let instr = memory.read_u32(0x1004).unwrap(); assert_eq!(instr, 0x02a00093); // addi x1, x0, 42 @@ -1133,10 +1266,10 @@ mod tests { let code: Vec = (0..1024).map(|i| (i % 256) as u8).collect(); let elf_data = build_test_elf(&code, 0x10000, 0x10000); let loader = ElfLoader::parse(&elf_data).unwrap(); - + let mut memory = Memory::with_default_size(); loader.load_into_memory(&mut memory).unwrap(); - + // Verify pattern for i in 0..1024u32 { let expected = (i % 256) as u8; @@ -1150,7 +1283,7 @@ mod tests { let code = vec![0x00; 8]; let elf_data = build_test_elf(&code, 0x80000000, 0x80000000); let loader = ElfLoader::parse(&elf_data).unwrap(); - + let header = loader.header(); assert_eq!(header.e_type, ET_EXEC); assert_eq!(header.e_machine, EM_RISCV); diff --git a/crates/executor/src/error.rs b/crates/executor/src/error.rs index d7f1ded..3fd3f99 100644 --- a/crates/executor/src/error.rs +++ b/crates/executor/src/error.rs @@ -12,10 +12,11 @@ pub enum ExecutorError { // === Unprovable Traps === // These errors indicate operations that cannot be proven in our constraint system. // Programs containing these will fail during proving. - /// ECALL instruction encountered - only specific syscalls are supported (0x1000=Keccak256, 93=exit). /// This is an unprovable trap that will cause prover failure. - #[error("Unprovable trap: ECALL (syscall {syscall_id}) at pc={pc:#x} - unsupported system call")] + #[error( + "Unprovable trap: ECALL (syscall {syscall_id}) at pc={pc:#x} - unsupported system call" + )] Ecall { pc: u32, syscall_id: u32 }, /// EBREAK instruction encountered - debug breakpoints not supported. @@ -36,10 +37,13 @@ pub enum ExecutorError { /// Unaligned memory access - words must be 4-byte aligned, halfwords 2-byte aligned. /// This is an unprovable trap that will cause prover failure. #[error("Unprovable trap: Unaligned {access_type} access at addr={addr:#x} (alignment required: {required} bytes)")] - UnalignedAccess { addr: u32, access_type: &'static str, required: u8 }, + UnalignedAccess { + addr: u32, + access_type: &'static str, + required: u8, + }, // === Execution Errors === - #[error("Invalid instruction at pc={pc:#x}: {bits:#010x}")] InvalidInstruction { pc: u32, bits: u32 }, @@ -53,7 +57,6 @@ pub enum ExecutorError { UnknownSyscall { pc: u32, syscall_code: u32 }, // === Normal Termination === - #[error("Execution halted: reached max steps ({max_steps})")] MaxStepsReached { max_steps: u64 }, diff --git a/crates/executor/src/lib.rs b/crates/executor/src/lib.rs index d5d6e3e..7363776 100644 --- a/crates/executor/src/lib.rs +++ b/crates/executor/src/lib.rs @@ -7,16 +7,16 @@ //! - Syscall/precompile hooks for delegation pub mod cpu; -pub mod memory; pub mod decode; -pub mod trace; -pub mod error; pub mod elf; +pub mod error; +pub mod memory; pub mod syscall; +pub mod trace; pub use cpu::Cpu; -pub use memory::Memory; -pub use trace::{ExecutionTrace, TraceRow}; -pub use error::ExecutorError; pub use elf::ElfLoader; +pub use error::ExecutorError; +pub use memory::Memory; pub use syscall::SyscallCode; +pub use trace::{ExecutionTrace, TraceRow}; diff --git a/crates/executor/src/memory.rs b/crates/executor/src/memory.rs index 52e3be3..d02afaf 100644 --- a/crates/executor/src/memory.rs +++ b/crates/executor/src/memory.rs @@ -63,14 +63,14 @@ impl Memory { } /// Read a halfword (16-bit) from memory (little-endian). - /// + /// /// # Errors /// Returns `UnalignedAccess` if addr is not 2-byte aligned (unprovable trap). #[inline] pub fn read_u16(&self, addr: u32) -> Result { if addr & 1 != 0 { - return Err(ExecutorError::UnalignedAccess { - addr, + return Err(ExecutorError::UnalignedAccess { + addr, access_type: "halfword read", required: 2, }); @@ -83,14 +83,14 @@ impl Memory { } /// Read a word (32-bit) from memory (little-endian). - /// + /// /// # Errors /// Returns `UnalignedAccess` if addr is not 4-byte aligned (unprovable trap). #[inline] pub fn read_u32(&self, addr: u32) -> Result { if addr & 3 != 0 { - return Err(ExecutorError::UnalignedAccess { - addr, + return Err(ExecutorError::UnalignedAccess { + addr, access_type: "word read", required: 4, }); @@ -119,14 +119,14 @@ impl Memory { } /// Write a halfword (16-bit) to memory (little-endian). - /// + /// /// # Errors /// Returns `UnalignedAccess` if addr is not 2-byte aligned (unprovable trap). #[inline] pub fn write_u16(&mut self, addr: u32, val: u16) -> Result<(), ExecutorError> { if addr & 1 != 0 { - return Err(ExecutorError::UnalignedAccess { - addr, + return Err(ExecutorError::UnalignedAccess { + addr, access_type: "halfword write", required: 2, }); @@ -142,14 +142,14 @@ impl Memory { } /// Write a word (32-bit) to memory (little-endian). - /// + /// /// # Errors /// Returns `UnalignedAccess` if addr is not 4-byte aligned (unprovable trap). #[inline] pub fn write_u32(&mut self, addr: u32, val: u32) -> Result<(), ExecutorError> { if addr & 3 != 0 { - return Err(ExecutorError::UnalignedAccess { - addr, + return Err(ExecutorError::UnalignedAccess { + addr, access_type: "word write", required: 4, }); diff --git a/crates/executor/src/syscall.rs b/crates/executor/src/syscall.rs index ecbde83..4c06c3e 100644 --- a/crates/executor/src/syscall.rs +++ b/crates/executor/src/syscall.rs @@ -37,73 +37,70 @@ #[repr(u32)] pub enum SyscallCode { // === I/O Syscalls === - /// Halt execution (normal program exit) /// - a0: exit code HALT = 0x00, - + /// Write data to host output /// - a0: file descriptor (1=stdout, 2=stderr) /// - a1: pointer to data /// - a2: length in bytes /// Returns: bytes written in a0 WRITE = 0x01, - + /// Read data from host input /// - a0: pointer to input buffer /// - a1: maximum bytes to read /// Returns: bytes read in a0 READ = 0x02, - + /// Commit value to public outputs (journal) /// - a0: pointer to data /// - a1: length in bytes COMMIT = 0x03, - + /// Hint data to prover (not verified, for optimization) /// - a0: pointer to hint data /// - a1: length in bytes HINT = 0x04, - + // === Cryptographic Hash Functions === - /// Keccak-256 hash /// - a0: input pointer /// - a1: input length /// - a2: output pointer (32 bytes) /// Returns: 0 on success KECCAK256 = 0x10, - + /// SHA-256 hash /// - a0: input pointer /// - a1: input length /// - a2: output pointer (32 bytes) /// Returns: 0 on success SHA256 = 0x11, - + /// RIPEMD-160 hash /// - a0: input pointer /// - a1: input length /// - a2: output pointer (20 bytes) /// Returns: 0 on success RIPEMD160 = 0x12, - + /// Blake2b hash (64-byte output) /// - a0: input pointer /// - a1: input length /// - a2: output pointer (64 bytes) /// Returns: 0 on success BLAKE2B = 0x13, - + /// Blake3 hash (32-byte output) /// - a0: input pointer /// - a1: input length /// - a2: output pointer (32 bytes) /// Returns: 0 on success BLAKE3 = 0x14, - + // === Ethereum Precompiles === - /// ECRECOVER signature recovery /// - a0: message hash pointer (32 bytes) /// - a1: v value (recovery id) @@ -112,7 +109,7 @@ pub enum SyscallCode { /// - a4: output pointer (20 bytes - address) /// Returns: 0 on success, 1 on invalid signature ECRECOVER = 0x20, - + /// Modular exponentiation (for RSA, Ethereum MODEXP precompile) /// - a0: base pointer (32 bytes) /// - a1: exponent pointer (32 bytes) @@ -120,46 +117,44 @@ pub enum SyscallCode { /// - a3: result pointer (32 bytes) /// Returns: 0 on success MODEXP = 0x21, - + // === Elliptic Curve Operations === - /// BN254 G1 point addition /// - a0: point A pointer (64 bytes: x, y) /// - a1: point B pointer (64 bytes: x, y) /// - a2: result pointer (64 bytes: x, y) /// Returns: 0 on success BN254_G1_ADD = 0x30, - + /// BN254 G1 scalar multiplication /// - a0: point pointer (64 bytes: x, y) /// - a1: scalar pointer (32 bytes) /// - a2: result pointer (64 bytes: x, y) /// Returns: 0 on success BN254_G1_MUL = 0x31, - + /// BN254 pairing check /// - a0: G1 points array pointer /// - a1: G2 points array pointer /// - a2: number of pairs /// Returns: 1 if pairing equals 1, 0 otherwise BN254_PAIRING = 0x32, - + /// BLS12-381 G1 point addition /// - a0: point A pointer (96 bytes: x, y) /// - a1: point B pointer (96 bytes: x, y) /// - a2: result pointer (96 bytes: x, y) /// Returns: 0 on success BLS12381_G1_ADD = 0x40, - + /// BLS12-381 aggregate signatures /// - a0: signatures array pointer /// - a1: number of signatures /// - a2: result pointer /// Returns: 0 on success BLS12381_AGGREGATE = 0x41, - + // === Ed25519 Signatures === - /// Ed25519 signature verification /// - a0: message pointer /// - a1: message length @@ -167,9 +162,8 @@ pub enum SyscallCode { /// - a3: public key pointer (32 bytes) /// Returns: 1 if valid, 0 if invalid ED25519_VERIFY = 0x50, - + // === Legacy Compatibility === - /// Linux exit syscall (for compatibility with standard RISC-V programs) /// - a0: exit code EXIT = 93, @@ -201,12 +195,12 @@ impl SyscallCode { _ => None, } } - + /// Get the syscall code as u32 pub fn as_u32(self) -> u32 { self as u32 } - + /// Get the name of the syscall pub fn name(self) -> &'static str { match self { @@ -236,7 +230,7 @@ impl SyscallCode { #[cfg(test)] mod tests { use super::*; - + #[test] fn test_syscall_roundtrip() { let codes = [ @@ -247,19 +241,19 @@ mod tests { SyscallCode::BN254_PAIRING, SyscallCode::EXIT, ]; - + for code in codes.iter() { let raw = code.as_u32(); let parsed = SyscallCode::from_u32(raw); assert_eq!(Some(*code), parsed); } } - + #[test] fn test_unknown_syscall() { assert_eq!(None, SyscallCode::from_u32(0xFFFF)); } - + #[test] fn test_syscall_names() { assert_eq!(SyscallCode::KECCAK256.name(), "KECCAK256"); diff --git a/crates/executor/src/trace.rs b/crates/executor/src/trace.rs index 72194f1..2804b05 100644 --- a/crates/executor/src/trace.rs +++ b/crates/executor/src/trace.rs @@ -3,8 +3,8 @@ //! Each step of execution produces a TraceRow capturing the CPU state, //! instruction, and any memory operations. -use serde::{Deserialize, Serialize}; use crate::decode::DecodedInstr; +use serde::{Deserialize, Serialize}; /// Memory operation type. #[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)] @@ -24,17 +24,38 @@ pub enum MemOp { /// Store word (SW). StoreWord { addr: u32, value: u32 }, /// Keccak256 hash operation (delegated to specialized circuit). - Keccak256 { input_ptr: u32, input_len: u32, output_ptr: u32 }, + Keccak256 { + input_ptr: u32, + input_len: u32, + output_ptr: u32, + }, /// ECRECOVER signature verification (delegated to specialized circuit). Ecrecover { input_ptr: u32, output_ptr: u32 }, /// SHA-256 hash operation (delegated to specialized circuit). - Sha256 { message_ptr: usize, message_len: usize, digest_ptr: usize }, + Sha256 { + message_ptr: usize, + message_len: usize, + digest_ptr: usize, + }, /// RIPEMD-160 hash operation (delegated to specialized circuit). - Ripemd160 { message_ptr: usize, message_len: usize, digest_ptr: usize }, + Ripemd160 { + message_ptr: usize, + message_len: usize, + digest_ptr: usize, + }, /// Modular exponentiation (delegated to specialized circuit for RSA/crypto). - Modexp { base_ptr: usize, exp_ptr: usize, mod_ptr: usize, result_ptr: usize }, + Modexp { + base_ptr: usize, + exp_ptr: usize, + mod_ptr: usize, + result_ptr: usize, + }, /// Blake2b hash operation (delegated to specialized circuit). - Blake2b { message_ptr: usize, message_len: usize, digest_ptr: usize }, + Blake2b { + message_ptr: usize, + message_len: usize, + digest_ptr: usize, + }, } /// Flags indicating instruction class for AIR constraint selection. diff --git a/crates/executor/tests/test_blake2b.rs b/crates/executor/tests/test_blake2b.rs index 9039b3c..bccfc88 100644 --- a/crates/executor/tests/test_blake2b.rs +++ b/crates/executor/tests/test_blake2b.rs @@ -8,29 +8,30 @@ const BLAKE2B_SYSCALL: u32 = 0x1005; fn test_blake2b_syscall_empty() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + // Set up registers for the syscall - cpu.set_reg(10, input_ptr); // a0 = message_ptr - cpu.set_reg(11, 0); // a1 = message_len (empty) - cpu.set_reg(12, output_ptr); // a2 = digest_ptr + cpu.set_reg(10, input_ptr); // a0 = message_ptr + cpu.set_reg(11, 0); // a1 = message_len (empty) + cpu.set_reg(12, output_ptr); // a2 = digest_ptr cpu.set_reg(17, BLAKE2B_SYSCALL); // a7 = Blake2b syscall number - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (blake2b) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (blake2b) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until Blake2b syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -38,16 +39,17 @@ fn test_blake2b_syscall_empty() { } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 64).unwrap(); - + // Expected Blake2b-512 of empty string let expected = hex::decode( "786a02f742015903c6c6fd852552d272912f4740e15847618a86e217f71f5419\ - d25e1031afee585313896444934eb04b903a685b1448b755d56f701afe9be2ce" - ).unwrap(); - + d25e1031afee585313896444934eb04b903a685b1448b755d56f701afe9be2ce", + ) + .unwrap(); + assert_eq!(digest, expected, "Blake2b empty string mismatch"); assert_eq!(cpu.get_reg(10), 0, "Blake2b should return success"); } @@ -56,35 +58,36 @@ fn test_blake2b_syscall_empty() { fn test_blake2b_syscall_abc() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; let message = b"abc"; - + // Write message to memory for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall cpu.set_reg(10, input_ptr); cpu.set_reg(11, message.len() as u32); cpu.set_reg(12, output_ptr); cpu.set_reg(17, BLAKE2B_SYSCALL); - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (blake2b) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (blake2b) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until Blake2b syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -92,16 +95,17 @@ fn test_blake2b_syscall_abc() { } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 64).unwrap(); - + // Expected Blake2b-512 of "abc" let expected = hex::decode( "ba80a53f981c4d0d6a2797b69f12f6e94c212f14685ac4b74b12bb6fdbffa2d1\ - 7d87c5392aab792dc252d5de4533cc9518d38aa8dbf1925ab92386edd4009923" - ).unwrap(); - + 7d87c5392aab792dc252d5de4533cc9518d38aa8dbf1925ab92386edd4009923", + ) + .unwrap(); + assert_eq!(digest, expected, "Blake2b 'abc' mismatch"); assert_eq!(cpu.get_reg(10), 0, "Blake2b should return success"); } @@ -110,35 +114,36 @@ fn test_blake2b_syscall_abc() { fn test_blake2b_syscall_hello() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; let message = b"hello world"; - + // Write message to memory for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall cpu.set_reg(10, input_ptr); cpu.set_reg(11, message.len() as u32); cpu.set_reg(12, output_ptr); cpu.set_reg(17, BLAKE2B_SYSCALL); - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (blake2b) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (blake2b) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until Blake2b syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -146,11 +151,11 @@ fn test_blake2b_syscall_hello() { } let _ = cpu.step(); } - + // Read the digest - just verify it's 64 bytes and deterministic let digest = cpu.memory.slice(output_ptr, 64).unwrap(); assert_eq!(digest.len(), 64); - + // Run again to verify determinism let mut cpu2 = Cpu::new(); cpu2.enable_tracing(); @@ -169,7 +174,7 @@ fn test_blake2b_syscall_hello() { let _ = cpu2.step(); } let digest2 = cpu2.memory.slice(output_ptr, 64).unwrap(); - + assert_eq!(digest, digest2, "Blake2b should be deterministic"); assert_eq!(cpu.get_reg(10), 0, "Blake2b should return success"); } @@ -178,37 +183,38 @@ fn test_blake2b_syscall_hello() { fn test_blake2b_syscall_long() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x3000; - + // Create a longer message (1KB) let message = vec![0x42u8; 1024]; - + // Write message to memory for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall cpu.set_reg(10, input_ptr); cpu.set_reg(11, message.len() as u32); cpu.set_reg(12, output_ptr); cpu.set_reg(17, BLAKE2B_SYSCALL); - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (blake2b) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (blake2b) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until Blake2b syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -216,15 +222,15 @@ fn test_blake2b_syscall_long() { } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 64).unwrap(); assert_eq!(digest.len(), 64); - + // Verify it's not all zeros let is_nonzero = digest.iter().any(|&b| b != 0); assert!(is_nonzero, "Hash should not be all zeros"); - + assert_eq!(cpu.get_reg(10), 0, "Blake2b should return success"); } @@ -232,37 +238,38 @@ fn test_blake2b_syscall_long() { fn test_blake2b_syscall_zcash() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + // Simulate Zcash transaction data let message = b"zcash_transaction_example"; - + // Write message to memory for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall cpu.set_reg(10, input_ptr); cpu.set_reg(11, message.len() as u32); cpu.set_reg(12, output_ptr); cpu.set_reg(17, BLAKE2B_SYSCALL); - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (blake2b) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (blake2b) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until Blake2b syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -270,12 +277,12 @@ fn test_blake2b_syscall_zcash() { } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 64).unwrap(); - + // Should produce valid 64-byte hash for Zcash compatibility assert_eq!(digest.len(), 64); - + assert_eq!(cpu.get_reg(10), 0, "Blake2b should return success"); } diff --git a/crates/executor/tests/test_ecrecover.rs b/crates/executor/tests/test_ecrecover.rs index f03157a..9844242 100644 --- a/crates/executor/tests/test_ecrecover.rs +++ b/crates/executor/tests/test_ecrecover.rs @@ -10,54 +10,55 @@ fn test_ecrecover_syscall_valid() { let input_ptr = 0x1000; let output_ptr = 0x2000; - + // Create a test signature (we'll use generated data) // For a real test, we'd use a known Ethereum signature - use secp256k1::{Secp256k1, Message, SecretKey}; use secp256k1::ecdsa::RecoverableSignature; - + use secp256k1::{Message, Secp256k1, SecretKey}; + let secp = Secp256k1::new(); let secret_key = SecretKey::from_slice(&[0xCD; 32]).unwrap(); let public_key = secp256k1::PublicKey::from_secret_key(&secp, &secret_key); - + // Create message and sign it let message = [0xAA; 32]; let msg = Message::from_digest_slice(&message).unwrap(); let sig: RecoverableSignature = secp.sign_ecdsa_recoverable(&msg, &secret_key); - + let (recovery_id, sig_bytes) = sig.serialize_compact(); let v = recovery_id.to_i32() as u8; - + // Prepare input: hash || v || r || s let mut input = vec![0u8; 97]; input[0..32].copy_from_slice(&message); input[32] = v; input[33..65].copy_from_slice(&sig_bytes[0..32]); // r input[65..97].copy_from_slice(&sig_bytes[32..64]); // s - + // Write input to memory for (i, &byte) in input.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall - cpu.set_reg(10, input_ptr); // a0 = input_ptr - cpu.set_reg(11, output_ptr); // a1 = output_ptr - cpu.set_reg(17, 0x1001); // a7 = ecrecover syscall number - + cpu.set_reg(10, input_ptr); // a0 = input_ptr + cpu.set_reg(11, output_ptr); // a1 = output_ptr + cpu.set_reg(17, 0x1001); // a7 = ecrecover syscall number + // Create a simple program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (ecrecover) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (ecrecover) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - - let program_bytes: Vec = program.iter() + + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); - + cpu.load_program(0, &program_bytes).unwrap(); - + // Run until exit let mut steps = 0; let max_steps = 100; @@ -80,65 +81,83 @@ fn test_ecrecover_syscall_valid() { } } } - + // Verify the output address is non-zero (valid signature) let mut output = [0u8; 20]; for i in 0..20 { output[i] = cpu.memory.read_byte(output_ptr + i as u32).unwrap(); } - + // Should recover to a valid address (not all zeros) - assert_ne!(&output[..], &[0u8; 20], "ECRECOVER should recover a valid address"); - + assert_ne!( + &output[..], + &[0u8; 20], + "ECRECOVER should recover a valid address" + ); + // Compute expected address manually let pubkey_bytes = public_key.serialize_uncompressed(); let hash = zp1_delegation::keccak::keccak256(&pubkey_bytes[1..]); let expected_address = &hash[12..]; - - assert_eq!(&output[..], expected_address, - "ECRECOVER address mismatch!\nExpected: {:?}\nGot: {:?}", - expected_address, output); - + + assert_eq!( + &output[..], + expected_address, + "ECRECOVER address mismatch!\nExpected: {:?}\nGot: {:?}", + expected_address, + output + ); + // Verify the trace contains the ECRECOVER delegation let trace = cpu.take_trace().unwrap(); - let ecrecover_ops: Vec<_> = trace.rows.iter() + let ecrecover_ops: Vec<_> = trace + .rows + .iter() .filter(|row| matches!(row.mem_op, zp1_executor::trace::MemOp::Ecrecover { .. })) .collect(); - - assert_eq!(ecrecover_ops.len(), 1, "Should have exactly one ECRECOVER operation in trace"); + + assert_eq!( + ecrecover_ops.len(), + 1, + "Should have exactly one ECRECOVER operation in trace" + ); } /// Test ECRECOVER with invalid signature (should return zero address). #[test] fn test_ecrecover_invalid_signature() { let mut cpu = Cpu::new(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + // Create invalid signature (all zeros) let input = vec![0u8; 97]; - + for (i, &byte) in input.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + cpu.set_reg(10, input_ptr); cpu.set_reg(11, output_ptr); cpu.set_reg(17, 0x1001); - + cpu.load_program(0, &0x00000073u32.to_le_bytes()).unwrap(); - + let _ = cpu.step(); - + // Verify output is zero address (invalid signature) let mut output = [0u8; 20]; for i in 0..20 { output[i] = cpu.memory.read_byte(output_ptr + i as u32).unwrap(); } - - assert_eq!(&output[..], &[0u8; 20], "Invalid signature should produce zero address"); - + + assert_eq!( + &output[..], + &[0u8; 20], + "Invalid signature should produce zero address" + ); + // Check return value (a0 should be 1 for failure) assert_eq!(cpu.get_reg(10), 1, "Return value should indicate failure"); } @@ -147,34 +166,34 @@ fn test_ecrecover_invalid_signature() { #[test] fn test_ecrecover_invalid_v() { let mut cpu = Cpu::new(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + // Create signature with invalid v (99) let mut input = vec![0u8; 97]; input[32] = 99; // Invalid v input[33] = 0x01; // Some non-zero r input[65] = 0x01; // Some non-zero s - + for (i, &byte) in input.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + cpu.set_reg(10, input_ptr); cpu.set_reg(11, output_ptr); cpu.set_reg(17, 0x1001); - + cpu.load_program(0, &0x00000073u32.to_le_bytes()).unwrap(); - + let _ = cpu.step(); - + // Verify zero address let mut output = [0u8; 20]; for i in 0..20 { output[i] = cpu.memory.read_byte(output_ptr + i as u32).unwrap(); } - + assert_eq!(&output[..], &[0u8; 20]); assert_eq!(cpu.get_reg(10), 1); // Failure } @@ -183,49 +202,49 @@ fn test_ecrecover_invalid_v() { #[test] fn test_ecrecover_eip155() { let mut cpu = Cpu::new(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + // Create a valid signature with EIP-155 v (chainId=1, v=37 or 38) - use secp256k1::{Secp256k1, Message, SecretKey}; - + use secp256k1::{Message, Secp256k1, SecretKey}; + let secp = Secp256k1::new(); let secret_key = SecretKey::from_slice(&[0xAB; 32]).unwrap(); let message = [0xBB; 32]; let msg = Message::from_digest_slice(&message).unwrap(); let sig = secp.sign_ecdsa_recoverable(&msg, &secret_key); - + let (recovery_id, sig_bytes) = sig.serialize_compact(); - + // Convert to EIP-155: v = chainId * 2 + 35 + recovery_id let chain_id = 1u32; let v = (chain_id * 2 + 35 + recovery_id.to_i32() as u32) as u8; - + let mut input = vec![0u8; 97]; input[0..32].copy_from_slice(&message); input[32] = v; input[33..65].copy_from_slice(&sig_bytes[0..32]); input[65..97].copy_from_slice(&sig_bytes[32..64]); - + for (i, &byte) in input.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + cpu.set_reg(10, input_ptr); cpu.set_reg(11, output_ptr); cpu.set_reg(17, 0x1001); - + cpu.load_program(0, &0x00000073u32.to_le_bytes()).unwrap(); - + let _ = cpu.step(); - + // Verify valid address recovered let mut output = [0u8; 20]; for i in 0..20 { output[i] = cpu.memory.read_byte(output_ptr + i as u32).unwrap(); } - + assert_ne!(&output[..], &[0u8; 20], "EIP-155 signature should be valid"); assert_eq!(cpu.get_reg(10), 0); // Success } diff --git a/crates/executor/tests/test_keccak.rs b/crates/executor/tests/test_keccak.rs index 502667b..7f68dbc 100644 --- a/crates/executor/tests/test_keccak.rs +++ b/crates/executor/tests/test_keccak.rs @@ -11,37 +11,38 @@ fn test_keccak256_syscall() { // Allocate memory for input and output let input_ptr = 0x1000; let output_ptr = 0x2000; - + // Test vector: "hello" -> hash let input = b"hello"; - let expected_hash = hex::decode("1c8aff950685c2ed4bc3174f3472287b56d9517b9c948127319a09a7a36deac8") - .unwrap(); - + let expected_hash = + hex::decode("1c8aff950685c2ed4bc3174f3472287b56d9517b9c948127319a09a7a36deac8").unwrap(); + // Write input to memory for (i, &byte) in input.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers directly for the syscall - cpu.set_reg(10, input_ptr); // a0 = input_ptr - cpu.set_reg(11, 5); // a1 = input_len - cpu.set_reg(12, output_ptr); // a2 = output_ptr - cpu.set_reg(17, 0x1000); // a7 = keccak syscall number - + cpu.set_reg(10, input_ptr); // a0 = input_ptr + cpu.set_reg(11, 5); // a1 = input_len + cpu.set_reg(12, output_ptr); // a2 = output_ptr + cpu.set_reg(17, 0x1000); // a7 = keccak syscall number + // Create a simple program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (keccak) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (keccak) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program at address 0 - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); - + cpu.load_program(0, &program_bytes).unwrap(); - + // Run until exit let mut steps = 0; let max_steps = 100; @@ -65,56 +66,65 @@ fn test_keccak256_syscall() { } } } - + // Verify the output let mut output = [0u8; 32]; for i in 0..32 { output[i] = cpu.memory.read_byte(output_ptr + i as u32).unwrap(); } - - assert_eq!(&output[..], &expected_hash[..], - "Keccak256 hash mismatch!\nExpected: {:?}\nGot: {:?}", - expected_hash, output); - + + assert_eq!( + &output[..], + &expected_hash[..], + "Keccak256 hash mismatch!\nExpected: {:?}\nGot: {:?}", + expected_hash, + output + ); + // Verify the trace contains the Keccak delegation let trace = cpu.take_trace().unwrap(); - let keccak_ops: Vec<_> = trace.rows.iter() + let keccak_ops: Vec<_> = trace + .rows + .iter() .filter(|row| matches!(row.mem_op, zp1_executor::trace::MemOp::Keccak256 { .. })) .collect(); - - assert_eq!(keccak_ops.len(), 1, "Should have exactly one Keccak256 operation in trace"); + + assert_eq!( + keccak_ops.len(), + 1, + "Should have exactly one Keccak256 operation in trace" + ); } /// Test Keccak256 with empty input. #[test] fn test_keccak256_empty() { let mut cpu = Cpu::new(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - - let expected_hash = hex::decode("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470") - .unwrap(); - + + let expected_hash = + hex::decode("c5d2460186f7233c927e7db2dcc703c0e500b653ca82273b7bfad8045d85a470").unwrap(); + // Load registers directly cpu.set_reg(10, input_ptr); cpu.set_reg(11, 0); // length = 0 cpu.set_reg(12, output_ptr); cpu.set_reg(17, 0x1000); // a7 = keccak syscall - + // Create minimal program: ecall cpu.load_program(0, &0x00000073u32.to_le_bytes()).unwrap(); - + // Execute the syscall match cpu.step() { - Err(zp1_executor::ExecutorError::Ecall { syscall_id: 93, .. }) | - Ok(_) => { + Err(zp1_executor::ExecutorError::Ecall { syscall_id: 93, .. }) | Ok(_) => { // Check output let mut output = [0u8; 32]; for i in 0..32 { output[i] = cpu.memory.read_byte(output_ptr + i as u32).unwrap(); } - + assert_eq!(&output[..], &expected_hash[..]); } other => panic!("Unexpected result: {:?}", other), @@ -125,36 +135,36 @@ fn test_keccak256_empty() { #[test] fn test_keccak256_long_input() { let mut cpu = Cpu::new(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + // 200 bytes of input (requires 2 absorption rounds) let input = vec![0x42u8; 200]; let expected_hash = zp1_delegation::keccak::keccak256(&input); - + // Write input to memory for (i, &byte) in input.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Load registers cpu.set_reg(10, input_ptr); cpu.set_reg(11, 200); cpu.set_reg(12, output_ptr); cpu.set_reg(17, 0x1000); - + // Create program: ecall cpu.load_program(0, &0x00000073u32.to_le_bytes()).unwrap(); - + // Execute let _ = cpu.step(); - + // Verify output let mut output = [0u8; 32]; for i in 0..32 { output[i] = cpu.memory.read_byte(output_ptr + i as u32).unwrap(); } - + assert_eq!(&output[..], &expected_hash[..]); } diff --git a/crates/executor/tests/test_modexp.rs b/crates/executor/tests/test_modexp.rs index 2787bdc..8f7fadf 100644 --- a/crates/executor/tests/test_modexp.rs +++ b/crates/executor/tests/test_modexp.rs @@ -8,13 +8,13 @@ const MODEXP_SYSCALL: u32 = 0x1004; fn test_modexp_syscall_simple() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + // Test: 2^3 mod 5 = 8 mod 5 = 3 let base_ptr = 0x1000; let exp_ptr = 0x1020; let mod_ptr = 0x1040; let result_ptr = 0x2000; - + // Write base = 2 (little-endian, 32 bytes) let base_bytes: [u8; 32] = { let mut b = [0u8; 32]; @@ -24,7 +24,7 @@ fn test_modexp_syscall_simple() { for (i, &byte) in base_bytes.iter().enumerate() { cpu.memory.write_byte(base_ptr + i as u32, byte).unwrap(); } - + // Write exponent = 3 let exp_bytes: [u8; 32] = { let mut e = [0u8; 32]; @@ -34,7 +34,7 @@ fn test_modexp_syscall_simple() { for (i, &byte) in exp_bytes.iter().enumerate() { cpu.memory.write_byte(exp_ptr + i as u32, byte).unwrap(); } - + // Write modulus = 5 let mod_bytes: [u8; 32] = { let mut m = [0u8; 32]; @@ -44,27 +44,28 @@ fn test_modexp_syscall_simple() { for (i, &byte) in mod_bytes.iter().enumerate() { cpu.memory.write_byte(mod_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall - cpu.set_reg(10, base_ptr); // a0 = base_ptr - cpu.set_reg(11, exp_ptr); // a1 = exp_ptr - cpu.set_reg(12, mod_ptr); // a2 = mod_ptr - cpu.set_reg(13, result_ptr); // a3 = result_ptr + cpu.set_reg(10, base_ptr); // a0 = base_ptr + cpu.set_reg(11, exp_ptr); // a1 = exp_ptr + cpu.set_reg(12, mod_ptr); // a2 = mod_ptr + cpu.set_reg(13, result_ptr); // a3 = result_ptr cpu.set_reg(17, MODEXP_SYSCALL); // a7 = MODEXP syscall number - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (modexp) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (modexp) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until MODEXP syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -72,16 +73,16 @@ fn test_modexp_syscall_simple() { } let _ = cpu.step(); } - + // Read the result let result = cpu.memory.slice(result_ptr, 32).unwrap(); - + // Expected: 3 (2^3 mod 5 = 8 mod 5 = 3) assert_eq!(result[0], 3); for i in 1..32 { assert_eq!(result[i], 0); } - + assert_eq!(cpu.get_reg(10), 0, "MODEXP should return success"); } @@ -89,13 +90,13 @@ fn test_modexp_syscall_simple() { fn test_modexp_syscall_zero_exponent() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + // Test: 5^0 mod 7 = 1 (any number to power 0 is 1) let base_ptr = 0x1000; let exp_ptr = 0x1020; let mod_ptr = 0x1040; let result_ptr = 0x2000; - + // Write base = 5 let base_bytes: [u8; 32] = { let mut b = [0u8; 32]; @@ -105,13 +106,13 @@ fn test_modexp_syscall_zero_exponent() { for (i, &byte) in base_bytes.iter().enumerate() { cpu.memory.write_byte(base_ptr + i as u32, byte).unwrap(); } - + // Write exponent = 0 let exp_bytes: [u8; 32] = [0u8; 32]; for (i, &byte) in exp_bytes.iter().enumerate() { cpu.memory.write_byte(exp_ptr + i as u32, byte).unwrap(); } - + // Write modulus = 7 let mod_bytes: [u8; 32] = { let mut m = [0u8; 32]; @@ -121,27 +122,28 @@ fn test_modexp_syscall_zero_exponent() { for (i, &byte) in mod_bytes.iter().enumerate() { cpu.memory.write_byte(mod_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall cpu.set_reg(10, base_ptr); cpu.set_reg(11, exp_ptr); cpu.set_reg(12, mod_ptr); cpu.set_reg(13, result_ptr); cpu.set_reg(17, MODEXP_SYSCALL); - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (modexp) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (modexp) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until MODEXP syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -149,16 +151,16 @@ fn test_modexp_syscall_zero_exponent() { } let _ = cpu.step(); } - + // Read the result let result = cpu.memory.slice(result_ptr, 32).unwrap(); - + // Expected: 1 (x^0 = 1 for any x) assert_eq!(result[0], 1); for i in 1..32 { assert_eq!(result[i], 0); } - + assert_eq!(cpu.get_reg(10), 0, "MODEXP should return success"); } @@ -166,7 +168,7 @@ fn test_modexp_syscall_zero_exponent() { fn test_modexp_syscall_rsa_small() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + // Small RSA example: message^e mod n // m = 42, e = 17, n = 3233 (= 61 * 53) // c = 42^17 mod 3233 = 2557 @@ -174,7 +176,7 @@ fn test_modexp_syscall_rsa_small() { let exp_ptr = 0x1020; let mod_ptr = 0x1040; let result_ptr = 0x2000; - + // Write base = 42 let base_bytes: [u8; 32] = { let mut b = [0u8; 32]; @@ -184,7 +186,7 @@ fn test_modexp_syscall_rsa_small() { for (i, &byte) in base_bytes.iter().enumerate() { cpu.memory.write_byte(base_ptr + i as u32, byte).unwrap(); } - + // Write exponent = 17 let exp_bytes: [u8; 32] = { let mut e = [0u8; 32]; @@ -194,7 +196,7 @@ fn test_modexp_syscall_rsa_small() { for (i, &byte) in exp_bytes.iter().enumerate() { cpu.memory.write_byte(exp_ptr + i as u32, byte).unwrap(); } - + // Write modulus = 3233 (0x0CA1) let mod_bytes: [u8; 32] = { let mut m = [0u8; 32]; @@ -205,27 +207,28 @@ fn test_modexp_syscall_rsa_small() { for (i, &byte) in mod_bytes.iter().enumerate() { cpu.memory.write_byte(mod_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall cpu.set_reg(10, base_ptr); cpu.set_reg(11, exp_ptr); cpu.set_reg(12, mod_ptr); cpu.set_reg(13, result_ptr); cpu.set_reg(17, MODEXP_SYSCALL); - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (modexp) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (modexp) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until MODEXP syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -233,10 +236,10 @@ fn test_modexp_syscall_rsa_small() { } let _ = cpu.step(); } - + // Read the result let result = cpu.memory.slice(result_ptr, 32).unwrap(); - + // Expected: 2557 (0x09FD) let expected = 2557u16; let expected_bytes = expected.to_le_bytes(); @@ -245,7 +248,7 @@ fn test_modexp_syscall_rsa_small() { for i in 2..32 { assert_eq!(result[i], 0); } - + assert_eq!(cpu.get_reg(10), 0, "MODEXP should return success"); } @@ -253,7 +256,7 @@ fn test_modexp_syscall_rsa_small() { fn test_modexp_syscall_large_numbers() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + // Test with larger numbers to verify U256 handling // base = 0x123456789ABCDEF (56 bits) // exp = 0x3 @@ -262,7 +265,7 @@ fn test_modexp_syscall_large_numbers() { let exp_ptr = 0x1020; let mod_ptr = 0x1040; let result_ptr = 0x2000; - + // Write base let base_bytes: [u8; 32] = { let mut b = [0u8; 32]; @@ -279,7 +282,7 @@ fn test_modexp_syscall_large_numbers() { for (i, &byte) in base_bytes.iter().enumerate() { cpu.memory.write_byte(base_ptr + i as u32, byte).unwrap(); } - + // Write exponent = 3 let exp_bytes: [u8; 32] = { let mut e = [0u8; 32]; @@ -289,7 +292,7 @@ fn test_modexp_syscall_large_numbers() { for (i, &byte) in exp_bytes.iter().enumerate() { cpu.memory.write_byte(exp_ptr + i as u32, byte).unwrap(); } - + // Write modulus let mod_bytes: [u8; 32] = { let mut m = [0u8; 32]; @@ -306,27 +309,28 @@ fn test_modexp_syscall_large_numbers() { for (i, &byte) in mod_bytes.iter().enumerate() { cpu.memory.write_byte(mod_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall cpu.set_reg(10, base_ptr); cpu.set_reg(11, exp_ptr); cpu.set_reg(12, mod_ptr); cpu.set_reg(13, result_ptr); cpu.set_reg(17, MODEXP_SYSCALL); - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (modexp) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (modexp) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until MODEXP syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -334,10 +338,10 @@ fn test_modexp_syscall_large_numbers() { } let _ = cpu.step(); } - + // Read the result - just verify it computed something valid let result = cpu.memory.slice(result_ptr, 32).unwrap(); - + // Result should be less than modulus and non-zero let mut is_zero = true; for &byte in result.iter() { @@ -347,6 +351,6 @@ fn test_modexp_syscall_large_numbers() { } } assert!(!is_zero, "Result should be non-zero"); - + assert_eq!(cpu.get_reg(10), 0, "MODEXP should return success"); } diff --git a/crates/executor/tests/test_ripemd160.rs b/crates/executor/tests/test_ripemd160.rs index 39c95d9..d300c7e 100644 --- a/crates/executor/tests/test_ripemd160.rs +++ b/crates/executor/tests/test_ripemd160.rs @@ -8,48 +8,49 @@ const RIPEMD160_SYSCALL: u32 = 0x1003; fn test_ripemd160_syscall_empty() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; let message = b""; - + // Set up registers for the syscall - cpu.set_reg(10, input_ptr); // a0 = message_ptr - cpu.set_reg(11, 0); // a1 = message_len - cpu.set_reg(12, output_ptr); // a2 = digest_ptr + cpu.set_reg(10, input_ptr); // a0 = message_ptr + cpu.set_reg(11, 0); // a1 = message_len + cpu.set_reg(12, output_ptr); // a2 = digest_ptr cpu.set_reg(17, RIPEMD160_SYSCALL); // a7 = RIPEMD-160 syscall number - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (ripemd160) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (ripemd160) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until RIPEMD-160 syscall completes for _ in 0..5 { - if cpu.pc == 4 { // After first ecall + if cpu.pc == 4 { + // After first ecall break; } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 20).unwrap(); - + // Expected: 9c1185a5c5e9fc54612808977ee8f548b2258d31 let expected = [ - 0x9c, 0x11, 0x85, 0xa5, 0xc5, 0xe9, 0xfc, 0x54, - 0x61, 0x28, 0x08, 0x97, 0x7e, 0xe8, 0xf5, 0x48, - 0xb2, 0x25, 0x8d, 0x31, + 0x9c, 0x11, 0x85, 0xa5, 0xc5, 0xe9, 0xfc, 0x54, 0x61, 0x28, 0x08, 0x97, 0x7e, 0xe8, 0xf5, + 0x48, 0xb2, 0x25, 0x8d, 0x31, ]; - + assert_eq!(digest, expected, "RIPEMD-160 empty string mismatch"); assert_eq!(cpu.get_reg(10), 0, "RIPEMD-160 should return success"); } @@ -58,35 +59,36 @@ fn test_ripemd160_syscall_empty() { fn test_ripemd160_syscall_hello() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; let message = b"hello world"; - + // Write message to memory for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall - cpu.set_reg(10, input_ptr); // a0 = message_ptr + cpu.set_reg(10, input_ptr); // a0 = message_ptr cpu.set_reg(11, message.len() as u32); // a1 = message_len - cpu.set_reg(12, output_ptr); // a2 = digest_ptr + cpu.set_reg(12, output_ptr); // a2 = digest_ptr cpu.set_reg(17, RIPEMD160_SYSCALL); // a7 = RIPEMD-160 syscall number - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (ripemd160) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (ripemd160) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until RIPEMD-160 syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -94,17 +96,16 @@ fn test_ripemd160_syscall_hello() { } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 20).unwrap(); - + // Expected: 98c615784ccb5fe5936fbc0cbe9dfdb408d92f0f let expected = [ - 0x98, 0xc6, 0x15, 0x78, 0x4c, 0xcb, 0x5f, 0xe5, - 0x93, 0x6f, 0xbc, 0x0c, 0xbe, 0x9d, 0xfd, 0xb4, - 0x08, 0xd9, 0x2f, 0x0f, + 0x98, 0xc6, 0x15, 0x78, 0x4c, 0xcb, 0x5f, 0xe5, 0x93, 0x6f, 0xbc, 0x0c, 0xbe, 0x9d, 0xfd, + 0xb4, 0x08, 0xd9, 0x2f, 0x0f, ]; - + assert_eq!(digest, expected, "RIPEMD-160 'hello world' mismatch"); assert_eq!(cpu.get_reg(10), 0, "RIPEMD-160 should return success"); } @@ -113,35 +114,36 @@ fn test_ripemd160_syscall_hello() { fn test_ripemd160_syscall_abc() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; let message = b"abc"; - + // Write message to memory for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall - cpu.set_reg(10, input_ptr); // a0 = message_ptr + cpu.set_reg(10, input_ptr); // a0 = message_ptr cpu.set_reg(11, message.len() as u32); // a1 = message_len - cpu.set_reg(12, output_ptr); // a2 = digest_ptr + cpu.set_reg(12, output_ptr); // a2 = digest_ptr cpu.set_reg(17, RIPEMD160_SYSCALL); // a7 = RIPEMD-160 syscall number - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (ripemd160) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (ripemd160) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until RIPEMD-160 syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -149,17 +151,16 @@ fn test_ripemd160_syscall_abc() { } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 20).unwrap(); - + // Expected: 8eb208f7e05d987a9b044a8e98c6b087f15a0bfc let expected = [ - 0x8e, 0xb2, 0x08, 0xf7, 0xe0, 0x5d, 0x98, 0x7a, - 0x9b, 0x04, 0x4a, 0x8e, 0x98, 0xc6, 0xb0, 0x87, - 0xf1, 0x5a, 0x0b, 0xfc, + 0x8e, 0xb2, 0x08, 0xf7, 0xe0, 0x5d, 0x98, 0x7a, 0x9b, 0x04, 0x4a, 0x8e, 0x98, 0xc6, 0xb0, + 0x87, 0xf1, 0x5a, 0x0b, 0xfc, ]; - + assert_eq!(digest, expected, "RIPEMD-160 'abc' mismatch"); assert_eq!(cpu.get_reg(10), 0, "RIPEMD-160 should return success"); } @@ -168,43 +169,43 @@ fn test_ripemd160_syscall_abc() { fn test_ripemd160_syscall_bitcoin_address() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; - + // Simulate a Bitcoin address generation: RIPEMD-160(SHA-256(pubkey)) // This test uses a known SHA-256 output as input let sha256_output: [u8; 32] = [ - 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, - 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, - 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, - 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, 0xff, + 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, 0xee, + 0xff, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xaa, 0xbb, 0xcc, 0xdd, + 0xee, 0xff, ]; - + // Write SHA-256 output to memory for (i, &byte) in sha256_output.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall - cpu.set_reg(10, input_ptr); // a0 = message_ptr - cpu.set_reg(11, 32); // a1 = message_len (32 bytes from SHA-256) - cpu.set_reg(12, output_ptr); // a2 = digest_ptr + cpu.set_reg(10, input_ptr); // a0 = message_ptr + cpu.set_reg(11, 32); // a1 = message_len (32 bytes from SHA-256) + cpu.set_reg(12, output_ptr); // a2 = digest_ptr cpu.set_reg(17, RIPEMD160_SYSCALL); // a7 = RIPEMD-160 syscall number - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (ripemd160) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (ripemd160) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until RIPEMD-160 syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -212,13 +213,13 @@ fn test_ripemd160_syscall_bitcoin_address() { } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 20).unwrap(); - + // Verify it produces a valid 20-byte output assert_eq!(digest.len(), 20); - + // Verify success assert_eq!(cpu.get_reg(10), 0, "RIPEMD-160 should return success"); } diff --git a/crates/executor/tests/test_sha256.rs b/crates/executor/tests/test_sha256.rs index b6697da..218f6df 100644 --- a/crates/executor/tests/test_sha256.rs +++ b/crates/executor/tests/test_sha256.rs @@ -8,54 +8,55 @@ const SHA256_SYSCALL: u32 = 0x1002; fn test_sha256_syscall_empty() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; let message = b""; - + // Write message to memory (empty in this case) // (no writes needed for empty message) - + // Set up registers for the syscall - cpu.set_reg(10, input_ptr); // a0 = message_ptr - cpu.set_reg(11, 0); // a1 = message_len - cpu.set_reg(12, output_ptr); // a2 = digest_ptr + cpu.set_reg(10, input_ptr); // a0 = message_ptr + cpu.set_reg(11, 0); // a1 = message_len + cpu.set_reg(12, output_ptr); // a2 = digest_ptr cpu.set_reg(17, SHA256_SYSCALL); // a7 = SHA256 syscall number - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (sha256) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (sha256) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until SHA-256 syscall completes for _ in 0..5 { - if cpu.pc == 4 { // After first ecall + if cpu.pc == 4 { + // After first ecall break; } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 32).unwrap(); - + // Expected: e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 let expected = [ - 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, - 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, 0x24, - 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, - 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, 0xb8, 0x55, + 0xe3, 0xb0, 0xc4, 0x42, 0x98, 0xfc, 0x1c, 0x14, 0x9a, 0xfb, 0xf4, 0xc8, 0x99, 0x6f, 0xb9, + 0x24, 0x27, 0xae, 0x41, 0xe4, 0x64, 0x9b, 0x93, 0x4c, 0xa4, 0x95, 0x99, 0x1b, 0x78, 0x52, + 0xb8, 0x55, ]; - + assert_eq!(digest, expected, "SHA-256 empty string mismatch"); - + // Check return value (a0 should be 0 for success) assert_eq!(cpu.get_reg(10), 0, "SHA-256 should return success"); } @@ -64,35 +65,36 @@ fn test_sha256_syscall_empty() { fn test_sha256_syscall_hello() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; let message = b"hello world"; - + // Write message to memory for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall - cpu.set_reg(10, input_ptr); // a0 = message_ptr + cpu.set_reg(10, input_ptr); // a0 = message_ptr cpu.set_reg(11, message.len() as u32); // a1 = message_len - cpu.set_reg(12, output_ptr); // a2 = digest_ptr + cpu.set_reg(12, output_ptr); // a2 = digest_ptr cpu.set_reg(17, SHA256_SYSCALL); // a7 = SHA256 syscall number - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (sha256) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (sha256) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until SHA-256 syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -100,18 +102,17 @@ fn test_sha256_syscall_hello() { } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 32).unwrap(); - + // Expected: b94d27b9934d3e08a52e52d7da7dabfac484efe37a5380ee9088f7ace2efcde9 let expected = [ - 0xb9, 0x4d, 0x27, 0xb9, 0x93, 0x4d, 0x3e, 0x08, - 0xa5, 0x2e, 0x52, 0xd7, 0xda, 0x7d, 0xab, 0xfa, - 0xc4, 0x84, 0xef, 0xe3, 0x7a, 0x53, 0x80, 0xee, - 0x90, 0x88, 0xf7, 0xac, 0xe2, 0xef, 0xcd, 0xe9, + 0xb9, 0x4d, 0x27, 0xb9, 0x93, 0x4d, 0x3e, 0x08, 0xa5, 0x2e, 0x52, 0xd7, 0xda, 0x7d, 0xab, + 0xfa, 0xc4, 0x84, 0xef, 0xe3, 0x7a, 0x53, 0x80, 0xee, 0x90, 0x88, 0xf7, 0xac, 0xe2, 0xef, + 0xcd, 0xe9, ]; - + assert_eq!(digest, expected, "SHA-256 'hello world' mismatch"); assert_eq!(cpu.get_reg(10), 0, "SHA-256 should return success"); } @@ -120,35 +121,36 @@ fn test_sha256_syscall_hello() { fn test_sha256_syscall_abc() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; let message = b"abc"; - + // Write message to memory for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall - cpu.set_reg(10, input_ptr); // a0 = message_ptr + cpu.set_reg(10, input_ptr); // a0 = message_ptr cpu.set_reg(11, message.len() as u32); // a1 = message_len - cpu.set_reg(12, output_ptr); // a2 = digest_ptr + cpu.set_reg(12, output_ptr); // a2 = digest_ptr cpu.set_reg(17, SHA256_SYSCALL); // a7 = SHA256 syscall number - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (sha256) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (sha256) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until SHA-256 syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -156,18 +158,17 @@ fn test_sha256_syscall_abc() { } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 32).unwrap(); - + // Expected: ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad let expected = [ - 0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea, - 0x41, 0x41, 0x40, 0xde, 0x5d, 0xae, 0x22, 0x23, - 0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, 0x7a, 0x9c, - 0xb4, 0x10, 0xff, 0x61, 0xf2, 0x00, 0x15, 0xad, + 0xba, 0x78, 0x16, 0xbf, 0x8f, 0x01, 0xcf, 0xea, 0x41, 0x41, 0x40, 0xde, 0x5d, 0xae, 0x22, + 0x23, 0xb0, 0x03, 0x61, 0xa3, 0x96, 0x17, 0x7a, 0x9c, 0xb4, 0x10, 0xff, 0x61, 0xf2, 0x00, + 0x15, 0xad, ]; - + assert_eq!(digest, expected, "SHA-256 'abc' mismatch"); assert_eq!(cpu.get_reg(10), 0, "SHA-256 should return success"); } @@ -176,36 +177,37 @@ fn test_sha256_syscall_abc() { fn test_sha256_syscall_long_message() { let mut cpu = Cpu::new(); cpu.enable_tracing(); - + let input_ptr = 0x1000; let output_ptr = 0x2000; // Test with a message longer than 64 bytes to exercise multiple blocks let message = b"The quick brown fox jumps over the lazy dog. The quick brown fox jumps over the lazy dog."; - + // Write message to memory for (i, &byte) in message.iter().enumerate() { cpu.memory.write_byte(input_ptr + i as u32, byte).unwrap(); } - + // Set up registers for the syscall - cpu.set_reg(10, input_ptr); // a0 = message_ptr + cpu.set_reg(10, input_ptr); // a0 = message_ptr cpu.set_reg(11, message.len() as u32); // a1 = message_len - cpu.set_reg(12, output_ptr); // a2 = digest_ptr + cpu.set_reg(12, output_ptr); // a2 = digest_ptr cpu.set_reg(17, SHA256_SYSCALL); // a7 = SHA256 syscall number - + // Create program: ecall, then exit ecall let program: Vec = vec![ - 0x00000073, // ecall (sha256) - 0x05d00893, // li a7, 93 (exit syscall) - 0x00000073, // ecall (exit) + 0x00000073, // ecall (sha256) + 0x05d00893, // li a7, 93 (exit syscall) + 0x00000073, // ecall (exit) ]; - + // Load program - let program_bytes: Vec = program.iter() + let program_bytes: Vec = program + .iter() .flat_map(|instr| instr.to_le_bytes()) .collect(); cpu.memory.load_program(0, &program_bytes).unwrap(); - + // Run until SHA-256 syscall completes for _ in 0..5 { if cpu.pc == 4 { @@ -213,10 +215,10 @@ fn test_sha256_syscall_long_message() { } let _ = cpu.step(); } - + // Read the digest let digest = cpu.memory.slice(output_ptr, 32).unwrap(); - + // Verify it's deterministic let expected_digest = zp1_delegation::sha256::sha256(message); assert_eq!(digest, expected_digest, "SHA-256 long message mismatch"); diff --git a/crates/trace/src/columns.rs b/crates/trace/src/columns.rs index a44f112..6ac2b52 100644 --- a/crates/trace/src/columns.rs +++ b/crates/trace/src/columns.rs @@ -72,8 +72,8 @@ //! assert_eq!(column_vec.len(), 77); //! ``` -use zp1_primitives::M31; use zp1_executor::ExecutionTrace; +use zp1_primitives::M31; /// Number of columns in the CPU trace. pub const NUM_CPU_COLUMNS: usize = 77; @@ -193,7 +193,7 @@ pub struct TraceColumns { /// Multiply intermediate (64-bit product). pub mul_lo: Vec, pub mul_hi: Vec, - + // Auxiliary witnesses pub carry: Vec, pub borrow: Vec, @@ -204,22 +204,21 @@ pub struct TraceColumns { pub lt_result: Vec, pub eq_result: Vec, pub branch_taken: Vec, - + // Bitwise operation bit decompositions (32 bits each) // INPUT decompositions (needed for proper constraint verification) - pub rs1_bits: [Vec; 32], // Bit decomposition of rs1 - pub rs2_bits: [Vec; 32], // Bit decomposition of rs2 - pub imm_bits: [Vec; 32], // Bit decomposition of immediate (for ANDI/ORI/XORI) + pub rs1_bits: [Vec; 32], // Bit decomposition of rs1 + pub rs2_bits: [Vec; 32], // Bit decomposition of rs2 + pub imm_bits: [Vec; 32], // Bit decomposition of immediate (for ANDI/ORI/XORI) // OUTPUT decompositions (result bits) - pub and_bits: [Vec; 32], // Bit decomposition of AND result - pub xor_bits: [Vec; 32], // Bit decomposition of XOR result - pub or_bits: [Vec; 32], // Bit decomposition of OR result - + pub and_bits: [Vec; 32], // Bit decomposition of AND result + pub xor_bits: [Vec; 32], // Bit decomposition of XOR result + pub or_bits: [Vec; 32], // Bit decomposition of OR result + // ========================================================================= // Byte decomposition for 8-bit lookup tables (NEW - replaces bit constraints) // These columns enable 4 lookups per 32-bit operation instead of 32 polynomial constraints // ========================================================================= - /// Input byte decomposition: rs1 = rs1_bytes[0] + 256*rs1_bytes[1] + ... pub rs1_bytes: [Vec; 4], /// Input byte decomposition: rs2 = rs2_bytes[0] + 256*rs2_bytes[1] + ... @@ -313,7 +312,7 @@ impl TraceColumns { lt_result: Vec::new(), eq_result: Vec::new(), branch_taken: Vec::new(), - + // Initialize bit arrays (32 empty vectors for each) rs1_bits: std::array::from_fn(|_| Vec::new()), rs2_bits: std::array::from_fn(|_| Vec::new()), @@ -321,7 +320,7 @@ impl TraceColumns { and_bits: std::array::from_fn(|_| Vec::new()), xor_bits: std::array::from_fn(|_| Vec::new()), or_bits: std::array::from_fn(|_| Vec::new()), - + // Initialize byte arrays (4 bytes per 32-bit value) rs1_bytes: std::array::from_fn(|_| Vec::new()), rs2_bytes: std::array::from_fn(|_| Vec::new()), @@ -368,18 +367,47 @@ impl TraceColumns { let funct3 = (row.instr.bits >> 12) & 0x7; let funct7 = (row.instr.bits >> 25) & 0x7F; - let mut is_add = 0; let mut is_sub = 0; let mut is_and = 0; let mut is_or = 0; let mut is_xor = 0; - let mut is_sll = 0; let mut is_srl = 0; let mut is_sra = 0; let mut is_slt = 0; let mut is_sltu = 0; - let mut is_addi = 0; let mut is_andi = 0; let mut is_ori = 0; let mut is_xori = 0; - let mut is_slti = 0; let mut is_sltiu = 0; let mut is_slli = 0; let mut is_srli = 0; let mut is_srai = 0; - let mut is_lui = 0; let mut is_auipc = 0; - let mut is_beq = 0; let mut is_bne = 0; let mut is_blt = 0; let mut is_bge = 0; let mut is_bltu = 0; let mut is_bgeu = 0; - let mut is_jal = 0; let mut is_jalr = 0; - let mut is_mul = 0; let mut is_mulh = 0; let mut is_mulhsu = 0; let mut is_mulhu = 0; - let mut is_div = 0; let mut is_divu = 0; let mut is_rem = 0; let mut is_remu = 0; + let mut is_add = 0; + let mut is_sub = 0; + let mut is_and = 0; + let mut is_or = 0; + let mut is_xor = 0; + let mut is_sll = 0; + let mut is_srl = 0; + let mut is_sra = 0; + let mut is_slt = 0; + let mut is_sltu = 0; + let mut is_addi = 0; + let mut is_andi = 0; + let mut is_ori = 0; + let mut is_xori = 0; + let mut is_slti = 0; + let mut is_sltiu = 0; + let mut is_slli = 0; + let mut is_srli = 0; + let mut is_srai = 0; + let mut is_lui = 0; + let mut is_auipc = 0; + let mut is_beq = 0; + let mut is_bne = 0; + let mut is_blt = 0; + let mut is_bge = 0; + let mut is_bltu = 0; + let mut is_bgeu = 0; + let mut is_jal = 0; + let mut is_jalr = 0; + let mut is_mul = 0; + let mut is_mulh = 0; + let mut is_mulhsu = 0; + let mut is_mulhu = 0; + let mut is_div = 0; + let mut is_divu = 0; + let mut is_rem = 0; + let mut is_remu = 0; match opcode { - 0x33 => { // R-Type + 0x33 => { + // R-Type match (funct3, funct7) { (0x0, 0x00) => is_add = 1, (0x0, 0x20) => is_sub = 1, @@ -401,21 +429,29 @@ impl TraceColumns { (0x7, 0x01) => is_remu = 1, _ => {} } - }, - 0x13 => { // I-Type + } + 0x13 => { + // I-Type match funct3 { 0x0 => is_addi = 1, 0x1 => is_slli = 1, 0x2 => is_slti = 1, 0x3 => is_sltiu = 1, 0x4 => is_xori = 1, - 0x5 => if funct7 == 0x00 { is_srli = 1 } else if funct7 == 0x20 { is_srai = 1 }, + 0x5 => { + if funct7 == 0x00 { + is_srli = 1 + } else if funct7 == 0x20 { + is_srai = 1 + } + } 0x6 => is_ori = 1, 0x7 => is_andi = 1, _ => {} } - }, - 0x63 => { // Branch + } + 0x63 => { + // Branch match funct3 { 0x0 => is_beq = 1, 0x1 => is_bne = 1, @@ -425,7 +461,7 @@ impl TraceColumns { 0x7 => is_bgeu = 1, _ => {} } - }, + } 0x37 => is_lui = 1, 0x17 => is_auipc = 1, 0x6F => is_jal = 1, @@ -472,23 +508,45 @@ impl TraceColumns { cols.is_remu.push(M31::new(is_remu)); // Memory operation - let (mem_addr, mem_val, is_lb, is_lbu, is_lh, is_lhu, is_lw, is_sb, is_sh, is_sw, sb_carry_val) = match row.mem_op { + let ( + mem_addr, + mem_val, + is_lb, + is_lbu, + is_lh, + is_lhu, + is_lw, + is_sb, + is_sh, + is_sw, + sb_carry_val, + ) = match row.mem_op { zp1_executor::trace::MemOp::None => (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0), - zp1_executor::trace::MemOp::LoadByte { addr, value, signed } => { + zp1_executor::trace::MemOp::LoadByte { + addr, + value, + signed, + } => { if signed { (addr, value as u32, 1, 0, 0, 0, 0, 0, 0, 0, 0) } else { (addr, value as u32, 0, 1, 0, 0, 0, 0, 0, 0, 0) } - }, - zp1_executor::trace::MemOp::LoadHalf { addr, value, signed } => { + } + zp1_executor::trace::MemOp::LoadHalf { + addr, + value, + signed, + } => { if signed { (addr, value as u32, 0, 0, 1, 0, 0, 0, 0, 0, 0) } else { (addr, value as u32, 0, 0, 0, 1, 0, 0, 0, 0, 0) } - }, - zp1_executor::trace::MemOp::LoadWord { addr, value } => (addr, value, 0, 0, 0, 0, 1, 0, 0, 0, 0), + } + zp1_executor::trace::MemOp::LoadWord { addr, value } => { + (addr, value, 0, 0, 0, 0, 1, 0, 0, 0, 0) + } zp1_executor::trace::MemOp::StoreByte { addr, value } => { // sb_carry = (rs2_val_lo - mem_val_lo) / 256 // rs2_val_lo is the lower 16 bits of rs2. @@ -497,26 +555,42 @@ impl TraceColumns { let rs2_val = row.regs[row.instr.rs2 as usize]; let carry = (rs2_val & 0xFFFF) >> 8; (addr, value as u32, 0, 0, 0, 0, 0, 1, 0, 0, carry) - }, - zp1_executor::trace::MemOp::StoreHalf { addr, value } => (addr, value as u32, 0, 0, 0, 0, 0, 0, 1, 0, 0), - zp1_executor::trace::MemOp::StoreWord { addr, value } => (addr, value, 0, 0, 0, 0, 0, 0, 0, 1, 0), + } + zp1_executor::trace::MemOp::StoreHalf { addr, value } => { + (addr, value as u32, 0, 0, 0, 0, 0, 0, 1, 0, 0) + } + zp1_executor::trace::MemOp::StoreWord { addr, value } => { + (addr, value, 0, 0, 0, 0, 0, 0, 0, 1, 0) + } // Keccak256 is delegated to a separate circuit, so it doesn't appear in the main trace // The delegation link is recorded separately - zp1_executor::trace::MemOp::Keccak256 { .. } => (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0), + zp1_executor::trace::MemOp::Keccak256 { .. } => { + (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0) + } // ECRECOVER is also delegated to a separate circuit - zp1_executor::trace::MemOp::Ecrecover { .. } => (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0), + zp1_executor::trace::MemOp::Ecrecover { .. } => { + (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0) + } // SHA-256 is also delegated to a separate circuit - zp1_executor::trace::MemOp::Sha256 { .. } => (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0), + zp1_executor::trace::MemOp::Sha256 { .. } => { + (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0) + } // RIPEMD-160 is also delegated to a separate circuit - zp1_executor::trace::MemOp::Ripemd160 { .. } => (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0), - zp1_executor::trace::MemOp::Modexp { .. } => (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0), - zp1_executor::trace::MemOp::Blake2b { .. } => (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0), + zp1_executor::trace::MemOp::Ripemd160 { .. } => { + (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0) + } + zp1_executor::trace::MemOp::Modexp { .. } => { + (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0) + } + zp1_executor::trace::MemOp::Blake2b { .. } => { + (0u32, 0u32, 0, 0, 0, 0, 0, 0, 0, 0, 0) + } }; cols.mem_addr_lo.push(M31::new(mem_addr & 0xFFFF)); cols.mem_addr_hi.push(M31::new((mem_addr >> 16) & 0xFFFF)); cols.mem_val_lo.push(M31::new(mem_val & 0xFFFF)); cols.mem_val_hi.push(M31::new((mem_val >> 16) & 0xFFFF)); - + cols.is_lb.push(M31::new(is_lb)); cols.is_lbu.push(M31::new(is_lbu)); cols.is_lh.push(M31::new(is_lh)); @@ -534,21 +608,33 @@ impl TraceColumns { // Auxiliary witnesses let rs1_val = row.regs[row.instr.rs1 as usize]; let rs2_val = row.regs[row.instr.rs2 as usize]; - + // Carry (ADD) let carry = if is_add == 1 { let rs1_lo = rs1_val & 0xFFFF; let rs2_lo = rs2_val & 0xFFFF; - if rs1_lo + rs2_lo > 0xFFFF { 1 } else { 0 } - } else { 0 }; + if rs1_lo + rs2_lo > 0xFFFF { + 1 + } else { + 0 + } + } else { + 0 + }; cols.carry.push(M31::new(carry)); // Borrow (SUB) let borrow = if is_sub == 1 { let rs1_lo = rs1_val & 0xFFFF; let rs2_lo = rs2_val & 0xFFFF; - if rs1_lo < rs2_lo { 1 } else { 0 } - } else { 0 }; + if rs1_lo < rs2_lo { + 1 + } else { + 0 + } + } else { + 0 + }; cols.borrow.push(M31::new(borrow)); // Quotient/Remainder (DIV/REM) @@ -556,13 +642,15 @@ impl TraceColumns { // Simplified: assume signed division for DIV/REM, unsigned for DIVU/REMU // But for now, just use signed logic as placeholder or match instruction if rs2_val == 0 { - (0xFFFFFFFF, rs1_val) + (0xFFFFFFFF, rs1_val) } else { let q = (rs1_val as i32).wrapping_div(rs2_val as i32) as u32; let r = (rs1_val as i32).wrapping_rem(rs2_val as i32) as u32; (q, r) } - } else { (0, 0) }; + } else { + (0, 0) + }; cols.quotient_lo.push(M31::new(quot & 0xFFFF)); cols.quotient_hi.push(M31::new((quot >> 16) & 0xFFFF)); cols.remainder_lo.push(M31::new(rem & 0xFFFF)); @@ -572,24 +660,56 @@ impl TraceColumns { let lt = if is_slt == 1 || is_sltu == 1 || is_slti == 1 || is_sltiu == 1 { row.rd_val } else if is_blt == 1 || is_bge == 1 || is_bltu == 1 || is_bgeu == 1 { - match funct3 { - 4 | 5 => if (rs1_val as i32) < (rs2_val as i32) { 1 } else { 0 }, - 6 | 7 => if rs1_val < rs2_val { 1 } else { 0 }, - _ => 0, + match funct3 { + 4 | 5 => { + if (rs1_val as i32) < (rs2_val as i32) { + 1 + } else { + 0 + } + } + 6 | 7 => { + if rs1_val < rs2_val { + 1 + } else { + 0 + } } - } else { 0 }; + _ => 0, + } + } else { + 0 + }; cols.lt_result.push(M31::new(lt)); let eq = if is_beq == 1 || is_bne == 1 { - if rs1_val == rs2_val { 1 } else { 0 } - } else { 0 }; + if rs1_val == rs2_val { + 1 + } else { + 0 + } + } else { + 0 + }; cols.eq_result.push(M31::new(eq)); - let branch_taken = if is_beq == 1 || is_bne == 1 || is_blt == 1 || is_bge == 1 || is_bltu == 1 || is_bgeu == 1 { - if row.next_pc != row.pc.wrapping_add(4) { 1 } else { 0 } - } else { 0 }; + let branch_taken = if is_beq == 1 + || is_bne == 1 + || is_blt == 1 + || is_bge == 1 + || is_bltu == 1 + || is_bgeu == 1 + { + if row.next_pc != row.pc.wrapping_add(4) { + 1 + } else { + 0 + } + } else { + 0 + }; cols.branch_taken.push(M31::new(branch_taken)); - + // Bitwise operation bit decomposition // INPUT decomposition (rs1, rs2, and immediate) let imm_val = row.instr.imm as u32; @@ -598,7 +718,7 @@ impl TraceColumns { cols.rs2_bits[i].push(M31::new((rs2_val >> i) & 1)); cols.imm_bits[i].push(M31::new((imm_val >> i) & 1)); } - + // OUTPUT decomposition (AND/XOR/OR results) let and_result = rs1_val & rs2_val; let xor_result = rs1_val ^ rs2_val; @@ -608,7 +728,7 @@ impl TraceColumns { cols.xor_bits[i].push(M31::new((xor_result >> i) & 1)); cols.or_bits[i].push(M31::new((or_result >> i) & 1)); } - + // BYTE decomposition for 8-bit lookup table verification // Each 32-bit value is split into 4 bytes for efficient lookup for i in 0..4 { @@ -726,7 +846,7 @@ impl TraceColumns { self.lt_result.resize(target, M31::ZERO); self.eq_result.resize(target, M31::ZERO); self.branch_taken.resize(target, M31::ZERO); - + // Pad bit arrays (32 bits each for rs1/rs2 inputs and AND/XOR/OR outputs) for i in 0..32 { self.rs1_bits[i].resize(target, M31::ZERO); @@ -736,7 +856,7 @@ impl TraceColumns { self.xor_bits[i].resize(target, M31::ZERO); self.or_bits[i].resize(target, M31::ZERO); } - + // Pad byte arrays (4 bytes each for lookup table integration) for i in 0..4 { self.rs1_bytes[i].resize(target, M31::ZERO); @@ -844,25 +964,25 @@ impl TraceColumns { .chain(self.xor_result_bytes.iter().map(|v| v.clone())) .collect() } - + // ========================================================================= // MEMORY-EFFICIENT ACCESS METHODS // These methods avoid cloning data, reducing memory usage by ~50% // ========================================================================= - + /// Convert to columns by taking ownership (no cloning). - /// + /// /// This is 2x more memory-efficient than `to_columns()` because it /// moves the data instead of cloning it. Use this when you don't need /// to keep the TraceColumns after conversion. - /// + /// /// # Example /// ```ignore /// let columns = trace.into_columns(); // trace is consumed /// ``` pub fn into_columns(self) -> Vec> { let mut result = Vec::with_capacity(NUM_CPU_COLUMNS); - + // Core columns (moved, not cloned) result.push(self.clk); result.push(self.pc); @@ -941,61 +1061,83 @@ impl TraceColumns { result.push(self.lt_result); result.push(self.eq_result); result.push(self.branch_taken); - + // Bit columns - for col in self.rs1_bits { result.push(col); } - for col in self.rs2_bits { result.push(col); } - for col in self.imm_bits { result.push(col); } - for col in self.and_bits { result.push(col); } - for col in self.xor_bits { result.push(col); } - for col in self.or_bits { result.push(col); } - + for col in self.rs1_bits { + result.push(col); + } + for col in self.rs2_bits { + result.push(col); + } + for col in self.imm_bits { + result.push(col); + } + for col in self.and_bits { + result.push(col); + } + for col in self.xor_bits { + result.push(col); + } + for col in self.or_bits { + result.push(col); + } + // Byte columns - for col in self.rs1_bytes { result.push(col); } - for col in self.rs2_bytes { result.push(col); } - for col in self.and_result_bytes { result.push(col); } - for col in self.or_result_bytes { result.push(col); } - for col in self.xor_result_bytes { result.push(col); } - + for col in self.rs1_bytes { + result.push(col); + } + for col in self.rs2_bytes { + result.push(col); + } + for col in self.and_result_bytes { + result.push(col); + } + for col in self.or_result_bytes { + result.push(col); + } + for col in self.xor_result_bytes { + result.push(col); + } + result } - + /// Estimate memory usage in bytes. - /// + /// /// Useful for monitoring and deciding when to use streaming. pub fn memory_usage(&self) -> usize { let rows = self.len(); let m31_size = std::mem::size_of::(); - + // Count all columns let base_cols = 77; // Base trace columns let bit_cols = 32 * 6; // 6 bit arrays × 32 bits let byte_cols = 4 * 5; // 5 byte arrays × 4 bytes let total_cols = base_cols + bit_cols + byte_cols; - + rows * m31_size * total_cols } - + /// Estimate memory usage in MB. pub fn memory_usage_mb(&self) -> f64 { self.memory_usage() as f64 / (1024.0 * 1024.0) } - + /// Check if trace fits in available memory with given margin. - /// + /// /// Returns true if estimated memory usage is less than (available_mb - margin_mb). pub fn fits_in_memory(&self, available_mb: f64, margin_mb: f64) -> bool { self.memory_usage_mb() < (available_mb - margin_mb) } - + /// Get the number of rows that would fit in given memory budget. - /// + /// /// Useful for chunking large traces. pub fn rows_for_memory_budget(memory_mb: f64) -> usize { let m31_size = std::mem::size_of::(); let total_cols = 77 + 32 * 6 + 4 * 5; // 269 columns let bytes_per_row = m31_size * total_cols; - + let memory_bytes = (memory_mb * 1024.0 * 1024.0) as usize; memory_bytes / bytes_per_row } @@ -1038,10 +1180,9 @@ impl StreamingConfig { memory_budget_mb: memory_mb, } } - + /// Create config for M4 Mac with 24GB RAM (leaves 8GB for system). pub fn for_m4_mac() -> Self { Self::with_memory_mb(16.0 * 1024.0) // 16GB for prover } } -