diff --git a/crates/air/src/rv32im.rs b/crates/air/src/rv32im.rs index 8e98498..f0719aa 100644 --- a/crates/air/src/rv32im.rs +++ b/crates/air/src/rv32im.rs @@ -576,7 +576,8 @@ impl ConstraintEvaluator { let mut rd_reconstructed = M31::ZERO; let mut and_check = M31::ZERO; - for i in 0..31 { // First 31 bits (fit in M31) + // Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1) + for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; rs2_reconstructed += row.rs2_bits[i] * pow2; @@ -584,12 +585,6 @@ impl ConstraintEvaluator { // AND logic: and_bits[i] = rs1_bits[i] * rs2_bits[i] and_check += row.and_bits[i] - row.rs1_bits[i] * row.rs2_bits[i]; } - // Bit 31 separately to handle field overflow - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.and_bits[31] * pow2_30 * M31::new(2); - and_check += row.and_bits[31] - row.rs1_bits[31] * row.rs2_bits[31]; // All 4 checks in one constraint: // (rs1 reconstruction) + (rs2 reconstruction) + (AND logic) + (rd reconstruction) @@ -619,6 +614,7 @@ impl ConstraintEvaluator { let mut rd_reconstructed = M31::ZERO; let mut or_check = M31::ZERO; + // Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1) for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; @@ -628,13 +624,6 @@ impl ConstraintEvaluator { let expected_or = row.rs1_bits[i] + row.rs2_bits[i] - row.rs1_bits[i] * row.rs2_bits[i]; or_check += row.or_bits[i] - expected_or; } - // Bit 31 - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.or_bits[31] * pow2_30 * M31::new(2); - let expected_or = row.rs1_bits[31] + row.rs2_bits[31] - row.rs1_bits[31] * row.rs2_bits[31]; - or_check += row.or_bits[31] - expected_or; row.is_or * ( (rs1_full - rs1_reconstructed) + @@ -662,6 +651,7 @@ impl ConstraintEvaluator { let mut rd_reconstructed = M31::ZERO; let mut xor_check = M31::ZERO; + // Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1) for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; @@ -671,13 +661,6 @@ impl ConstraintEvaluator { let expected_xor = row.rs1_bits[i] + row.rs2_bits[i] - M31::new(2) * row.rs1_bits[i] * row.rs2_bits[i]; xor_check += row.xor_bits[i] - expected_xor; } - // Bit 31 - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.xor_bits[31] * pow2_30 * M31::new(2); - let expected_xor = row.rs1_bits[31] + row.rs2_bits[31] - M31::new(2) * row.rs1_bits[31] * row.rs2_bits[31]; - xor_check += row.xor_bits[31] - expected_xor; row.is_xor * ( (rs1_full - rs1_reconstructed) + @@ -909,6 +892,7 @@ impl ConstraintEvaluator { let mut rd_reconstructed = M31::ZERO; let mut and_check = M31::ZERO; + // Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1) for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; @@ -917,12 +901,6 @@ impl ConstraintEvaluator { // AND logic: and_bits[i] = rs1_bits[i] * imm_bits[i] and_check += row.and_bits[i] - row.rs1_bits[i] * row.imm_bits[i]; } - // Bit 31 - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.and_bits[31] * pow2_30 * M31::new(2); - and_check += row.and_bits[31] - row.rs1_bits[31] * row.imm_bits[31]; row.is_andi * ( (rs1_full - rs1_reconstructed) + @@ -950,6 +928,7 @@ impl ConstraintEvaluator { let mut rd_reconstructed = M31::ZERO; let mut or_check = M31::ZERO; + // Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1) for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; @@ -959,13 +938,6 @@ impl ConstraintEvaluator { let expected_or = row.rs1_bits[i] + row.imm_bits[i] - row.rs1_bits[i] * row.imm_bits[i]; or_check += row.or_bits[i] - expected_or; } - // Bit 31 - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.or_bits[31] * pow2_30 * M31::new(2); - let expected_or = row.rs1_bits[31] + row.imm_bits[31] - row.rs1_bits[31] * row.imm_bits[31]; - or_check += row.or_bits[31] - expected_or; row.is_ori * ( (rs1_full - rs1_reconstructed) + @@ -993,6 +965,7 @@ impl ConstraintEvaluator { let mut rd_reconstructed = M31::ZERO; let mut xor_check = M31::ZERO; + // Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1) for i in 0..31 { let pow2 = M31::new(1 << i); rs1_reconstructed += row.rs1_bits[i] * pow2; @@ -1002,13 +975,6 @@ impl ConstraintEvaluator { let expected_xor = row.rs1_bits[i] + row.imm_bits[i] - M31::new(2) * row.rs1_bits[i] * row.imm_bits[i]; xor_check += row.xor_bits[i] - expected_xor; } - // Bit 31 - let pow2_30 = M31::new(1 << 30); - rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2); - imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2); - rd_reconstructed += row.xor_bits[31] * pow2_30 * M31::new(2); - let expected_xor = row.rs1_bits[31] + row.imm_bits[31] - M31::new(2) * row.rs1_bits[31] * row.imm_bits[31]; - xor_check += row.xor_bits[31] - expected_xor; row.is_xori * ( (rs1_full - rs1_reconstructed) + @@ -1977,4 +1943,153 @@ mod tests { let c = ConstraintEvaluator::xor_constraint_lookup(&row); assert_eq!(c, M31::ZERO, "Lookup XOR constraint should be satisfied"); } + + /// Test that the fix prevents collision attacks by only using 31 bits. + /// Previously, values differing in bit 31 could collide in M31 representation. + #[test] + fn test_and_constraint_31bit_collision_prevented() { + // Test case: Simple 31-bit AND operation + // 0x00000055 & 0x000000FF = 0x00000055 + let mut row1 = CpuTraceRow::default(); + row1.is_and = M31::ONE; + row1.rs1_val_lo = M31::new(0x0055); + row1.rs1_val_hi = M31::ZERO; + row1.rs2_val_lo = M31::new(0x00FF); + row1.rs2_val_hi = M31::ZERO; + + // Set bits for rs1 = 0x55 (bits 0,2,4,6 set) + for i in 0..31 { + row1.rs1_bits[i] = if i < 8 && (0x55 & (1 << i)) != 0 { M31::ONE } else { M31::ZERO }; + } + row1.rs1_bits[31] = M31::ZERO; + + // Set bits for rs2 = 0xFF (bits 0-7 set) + for i in 0..31 { + row1.rs2_bits[i] = if i < 8 { M31::ONE } else { M31::ZERO }; + } + row1.rs2_bits[31] = M31::ZERO; + + // Result: 0x55 & 0xFF = 0x55 + row1.rd_val_lo = M31::new(0x0055); + row1.rd_val_hi = M31::ZERO; + for i in 0..31 { + row1.and_bits[i] = if i < 8 && (0x55 & (1 << i)) != 0 { M31::ONE } else { M31::ZERO }; + } + row1.and_bits[31] = M31::ZERO; + + let c1 = ConstraintEvaluator::and_constraint(&row1); + assert_eq!(c1, M31::ZERO, "AND constraint should be satisfied for 31-bit values"); + + // Test case 2: Demonstrate that values with bit 31 set (> 2^31-1) cannot be properly + // represented. The constraint system now treats values as 31-bit, preventing the + // collision attack where 0x00000001 and 0x80000000 would both map to M31(1). + // We verify that attempting to set a value with the high bit causes a constraint failure. + let mut row2 = CpuTraceRow::default(); + row2.is_and = M31::ONE; + + // Try to represent 0x80000000 using the hi/lo split: lo=0, hi=0x8000 + // But with only 31-bit reconstruction, this should fail + row2.rs1_val_lo = M31::ZERO; + row2.rs1_val_hi = M31::new(0x8000); // This represents bit 31 being set + row2.rs2_val_lo = M31::new(0xFFFF); + row2.rs2_val_hi = M31::new(0x7FFF); // Max 31-bit value in high part + + // Bit decomposition: all zeros (can't represent 0x80000000 with 31 bits) + for i in 0..32 { + row2.rs1_bits[i] = M31::ZERO; + row2.rs2_bits[i] = if i < 31 { M31::ONE } else { M31::ZERO }; + row2.and_bits[i] = M31::ZERO; + } + + row2.rd_val_lo = M31::ZERO; + row2.rd_val_hi = M31::new(0x8000); + + let c2 = ConstraintEvaluator::and_constraint(&row2); + // The constraint will fail because: + // rs1_full = 0 + 0x8000 * 2^16 = 0x80000000 in regular arithmetic + // But rs1_reconstructed from bits = 0 (since all bits are 0 and bit 31 is ignored) + // So rs1_full - rs1_reconstructed ≠ 0 in M31 + assert_ne!(c2, M31::ZERO, "AND constraint should fail when value exceeds 31-bit range"); + } + + #[test] + fn test_xor_constraint_31bit_collision_prevented() { + // Test a valid 31-bit XOR operation + let mut row = CpuTraceRow::default(); + row.is_xor = M31::ONE; + + // XOR: 0x00000055 ^ 0x000000AA = 0x000000FF (all within 31 bits) + row.rs1_val_lo = M31::new(0x0055); + row.rs1_val_hi = M31::ZERO; + row.rs2_val_lo = M31::new(0x00AA); + row.rs2_val_hi = M31::ZERO; + row.rd_val_lo = M31::new(0x00FF); + row.rd_val_hi = M31::ZERO; + + // Set bits for rs1 = 0x55 (bits 0,2,4,6 set) + for i in 0..31 { + row.rs1_bits[i] = if i < 8 && (0x55 & (1 << i)) != 0 { M31::ONE } else { M31::ZERO }; + } + row.rs1_bits[31] = M31::ZERO; + + // Set bits for rs2 = 0xAA (bits 1,3,5,7 set) + for i in 0..31 { + row.rs2_bits[i] = if i < 8 && (0xAA & (1 << i)) != 0 { M31::ONE } else { M31::ZERO }; + } + row.rs2_bits[31] = M31::ZERO; + + // Set bits for result = 0xFF (bits 0-7 set) + for i in 0..31 { + row.xor_bits[i] = if i < 8 { M31::ONE } else { M31::ZERO }; + } + row.xor_bits[31] = M31::ZERO; + + let c = ConstraintEvaluator::xor_constraint(&row); + assert_eq!(c, M31::ZERO, "XOR constraint should be satisfied for valid 31-bit values"); + } + + #[test] + fn test_or_constraint_31bit_values() { + // Test a valid 31-bit OR operation + let mut row = CpuTraceRow::default(); + row.is_or = M31::ONE; + + // OR: 0x12340000 | 0x00005678 = 0x12345678 (within 31 bits) + row.rs1_val_lo = M31::ZERO; + row.rs1_val_hi = M31::new(0x1234); + row.rs2_val_lo = M31::new(0x5678); + row.rs2_val_hi = M31::ZERO; + row.rd_val_lo = M31::new(0x5678); + row.rd_val_hi = M31::new(0x1234); + + // Set bits for rs1 (high 16 bits) + for i in 0..31 { + let val = if i >= 16 && i < 31 { (0x1234 >> (i - 16)) & 1 } else { 0 }; + row.rs1_bits[i] = if val != 0 { M31::ONE } else { M31::ZERO }; + } + row.rs1_bits[31] = M31::ZERO; + + // Set bits for rs2 (low 16 bits) + for i in 0..31 { + let val = if i < 16 { (0x5678 >> i) & 1 } else { 0 }; + row.rs2_bits[i] = if val != 0 { M31::ONE } else { M31::ZERO }; + } + row.rs2_bits[31] = M31::ZERO; + + // Set bits for result (combined) + for i in 0..31 { + let val = if i < 16 { + (0x5678 >> i) & 1 + } else if i < 31 { + (0x1234 >> (i - 16)) & 1 + } else { + 0 + }; + row.or_bits[i] = if val != 0 { M31::ONE } else { M31::ZERO }; + } + row.or_bits[31] = M31::ZERO; + + let c = ConstraintEvaluator::or_constraint(&row); + assert_eq!(c, M31::ZERO, "OR constraint should be satisfied for valid 31-bit values"); + } }