Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 156 additions & 41 deletions crates/air/src/rv32im.rs
Original file line number Diff line number Diff line change
Expand Up @@ -576,20 +576,15 @@ impl ConstraintEvaluator {
let mut rd_reconstructed = M31::ZERO;
let mut and_check = M31::ZERO;

for i in 0..31 { // First 31 bits (fit in M31)
// Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1)
for i in 0..31 {
let pow2 = M31::new(1 << i);
rs1_reconstructed += row.rs1_bits[i] * pow2;
rs2_reconstructed += row.rs2_bits[i] * pow2;
rd_reconstructed += row.and_bits[i] * pow2;
// AND logic: and_bits[i] = rs1_bits[i] * rs2_bits[i]
and_check += row.and_bits[i] - row.rs1_bits[i] * row.rs2_bits[i];
}
// Bit 31 separately to handle field overflow
let pow2_30 = M31::new(1 << 30);
rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2);
rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2);
rd_reconstructed += row.and_bits[31] * pow2_30 * M31::new(2);
and_check += row.and_bits[31] - row.rs1_bits[31] * row.rs2_bits[31];

// All 4 checks in one constraint:
// (rs1 reconstruction) + (rs2 reconstruction) + (AND logic) + (rd reconstruction)
Expand Down Expand Up @@ -619,6 +614,7 @@ impl ConstraintEvaluator {
let mut rd_reconstructed = M31::ZERO;
let mut or_check = M31::ZERO;

// Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1)
for i in 0..31 {
let pow2 = M31::new(1 << i);
rs1_reconstructed += row.rs1_bits[i] * pow2;
Expand All @@ -628,13 +624,6 @@ impl ConstraintEvaluator {
let expected_or = row.rs1_bits[i] + row.rs2_bits[i] - row.rs1_bits[i] * row.rs2_bits[i];
or_check += row.or_bits[i] - expected_or;
}
// Bit 31
let pow2_30 = M31::new(1 << 30);
rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2);
rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2);
rd_reconstructed += row.or_bits[31] * pow2_30 * M31::new(2);
let expected_or = row.rs1_bits[31] + row.rs2_bits[31] - row.rs1_bits[31] * row.rs2_bits[31];
or_check += row.or_bits[31] - expected_or;

row.is_or * (
(rs1_full - rs1_reconstructed) +
Expand Down Expand Up @@ -662,6 +651,7 @@ impl ConstraintEvaluator {
let mut rd_reconstructed = M31::ZERO;
let mut xor_check = M31::ZERO;

// Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1)
for i in 0..31 {
let pow2 = M31::new(1 << i);
rs1_reconstructed += row.rs1_bits[i] * pow2;
Expand All @@ -671,13 +661,6 @@ impl ConstraintEvaluator {
let expected_xor = row.rs1_bits[i] + row.rs2_bits[i] - M31::new(2) * row.rs1_bits[i] * row.rs2_bits[i];
xor_check += row.xor_bits[i] - expected_xor;
}
// Bit 31
let pow2_30 = M31::new(1 << 30);
rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2);
rs2_reconstructed += row.rs2_bits[31] * pow2_30 * M31::new(2);
rd_reconstructed += row.xor_bits[31] * pow2_30 * M31::new(2);
let expected_xor = row.rs1_bits[31] + row.rs2_bits[31] - M31::new(2) * row.rs1_bits[31] * row.rs2_bits[31];
xor_check += row.xor_bits[31] - expected_xor;

row.is_xor * (
(rs1_full - rs1_reconstructed) +
Expand Down Expand Up @@ -909,6 +892,7 @@ impl ConstraintEvaluator {
let mut rd_reconstructed = M31::ZERO;
let mut and_check = M31::ZERO;

// Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1)
for i in 0..31 {
let pow2 = M31::new(1 << i);
rs1_reconstructed += row.rs1_bits[i] * pow2;
Expand All @@ -917,12 +901,6 @@ impl ConstraintEvaluator {
// AND logic: and_bits[i] = rs1_bits[i] * imm_bits[i]
and_check += row.and_bits[i] - row.rs1_bits[i] * row.imm_bits[i];
}
// Bit 31
let pow2_30 = M31::new(1 << 30);
rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2);
imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2);
rd_reconstructed += row.and_bits[31] * pow2_30 * M31::new(2);
and_check += row.and_bits[31] - row.rs1_bits[31] * row.imm_bits[31];

row.is_andi * (
(rs1_full - rs1_reconstructed) +
Expand Down Expand Up @@ -950,6 +928,7 @@ impl ConstraintEvaluator {
let mut rd_reconstructed = M31::ZERO;
let mut or_check = M31::ZERO;

// Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1)
for i in 0..31 {
let pow2 = M31::new(1 << i);
rs1_reconstructed += row.rs1_bits[i] * pow2;
Expand All @@ -959,13 +938,6 @@ impl ConstraintEvaluator {
let expected_or = row.rs1_bits[i] + row.imm_bits[i] - row.rs1_bits[i] * row.imm_bits[i];
or_check += row.or_bits[i] - expected_or;
}
// Bit 31
let pow2_30 = M31::new(1 << 30);
rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2);
imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2);
rd_reconstructed += row.or_bits[31] * pow2_30 * M31::new(2);
let expected_or = row.rs1_bits[31] + row.imm_bits[31] - row.rs1_bits[31] * row.imm_bits[31];
or_check += row.or_bits[31] - expected_or;

row.is_ori * (
(rs1_full - rs1_reconstructed) +
Expand Down Expand Up @@ -993,6 +965,7 @@ impl ConstraintEvaluator {
let mut rd_reconstructed = M31::ZERO;
let mut xor_check = M31::ZERO;

// Only use bits 0-30 to avoid M31 field overflow (2^31 mod (2^31-1) = 1)
for i in 0..31 {
let pow2 = M31::new(1 << i);
rs1_reconstructed += row.rs1_bits[i] * pow2;
Expand All @@ -1002,13 +975,6 @@ impl ConstraintEvaluator {
let expected_xor = row.rs1_bits[i] + row.imm_bits[i] - M31::new(2) * row.rs1_bits[i] * row.imm_bits[i];
xor_check += row.xor_bits[i] - expected_xor;
}
// Bit 31
let pow2_30 = M31::new(1 << 30);
rs1_reconstructed += row.rs1_bits[31] * pow2_30 * M31::new(2);
imm_reconstructed += row.imm_bits[31] * pow2_30 * M31::new(2);
rd_reconstructed += row.xor_bits[31] * pow2_30 * M31::new(2);
let expected_xor = row.rs1_bits[31] + row.imm_bits[31] - M31::new(2) * row.rs1_bits[31] * row.imm_bits[31];
xor_check += row.xor_bits[31] - expected_xor;

row.is_xori * (
(rs1_full - rs1_reconstructed) +
Expand Down Expand Up @@ -1977,4 +1943,153 @@ mod tests {
let c = ConstraintEvaluator::xor_constraint_lookup(&row);
assert_eq!(c, M31::ZERO, "Lookup XOR constraint should be satisfied");
}

/// Test that the fix prevents collision attacks by only using 31 bits.
/// Previously, values differing in bit 31 could collide in M31 representation.
#[test]
fn test_and_constraint_31bit_collision_prevented() {
// Test case: Simple 31-bit AND operation
// 0x00000055 & 0x000000FF = 0x00000055
let mut row1 = CpuTraceRow::default();
row1.is_and = M31::ONE;
row1.rs1_val_lo = M31::new(0x0055);
row1.rs1_val_hi = M31::ZERO;
row1.rs2_val_lo = M31::new(0x00FF);
row1.rs2_val_hi = M31::ZERO;

// Set bits for rs1 = 0x55 (bits 0,2,4,6 set)
for i in 0..31 {
row1.rs1_bits[i] = if i < 8 && (0x55 & (1 << i)) != 0 { M31::ONE } else { M31::ZERO };
}
row1.rs1_bits[31] = M31::ZERO;

// Set bits for rs2 = 0xFF (bits 0-7 set)
for i in 0..31 {
row1.rs2_bits[i] = if i < 8 { M31::ONE } else { M31::ZERO };
}
row1.rs2_bits[31] = M31::ZERO;

// Result: 0x55 & 0xFF = 0x55
row1.rd_val_lo = M31::new(0x0055);
row1.rd_val_hi = M31::ZERO;
for i in 0..31 {
row1.and_bits[i] = if i < 8 && (0x55 & (1 << i)) != 0 { M31::ONE } else { M31::ZERO };
}
row1.and_bits[31] = M31::ZERO;

let c1 = ConstraintEvaluator::and_constraint(&row1);
assert_eq!(c1, M31::ZERO, "AND constraint should be satisfied for 31-bit values");

// Test case 2: Demonstrate that values with bit 31 set (> 2^31-1) cannot be properly
// represented. The constraint system now treats values as 31-bit, preventing the
// collision attack where 0x00000001 and 0x80000000 would both map to M31(1).
// We verify that attempting to set a value with the high bit causes a constraint failure.
let mut row2 = CpuTraceRow::default();
row2.is_and = M31::ONE;

// Try to represent 0x80000000 using the hi/lo split: lo=0, hi=0x8000
// But with only 31-bit reconstruction, this should fail
row2.rs1_val_lo = M31::ZERO;
row2.rs1_val_hi = M31::new(0x8000); // This represents bit 31 being set
row2.rs2_val_lo = M31::new(0xFFFF);
row2.rs2_val_hi = M31::new(0x7FFF); // Max 31-bit value in high part

// Bit decomposition: all zeros (can't represent 0x80000000 with 31 bits)
for i in 0..32 {
row2.rs1_bits[i] = M31::ZERO;
row2.rs2_bits[i] = if i < 31 { M31::ONE } else { M31::ZERO };
row2.and_bits[i] = M31::ZERO;
}

row2.rd_val_lo = M31::ZERO;
row2.rd_val_hi = M31::new(0x8000);

let c2 = ConstraintEvaluator::and_constraint(&row2);
// The constraint will fail because:
// rs1_full = 0 + 0x8000 * 2^16 = 0x80000000 in regular arithmetic
// But rs1_reconstructed from bits = 0 (since all bits are 0 and bit 31 is ignored)
// So rs1_full - rs1_reconstructed ≠ 0 in M31
assert_ne!(c2, M31::ZERO, "AND constraint should fail when value exceeds 31-bit range");
}

#[test]
fn test_xor_constraint_31bit_collision_prevented() {
// Test a valid 31-bit XOR operation
let mut row = CpuTraceRow::default();
row.is_xor = M31::ONE;

// XOR: 0x00000055 ^ 0x000000AA = 0x000000FF (all within 31 bits)
row.rs1_val_lo = M31::new(0x0055);
row.rs1_val_hi = M31::ZERO;
row.rs2_val_lo = M31::new(0x00AA);
row.rs2_val_hi = M31::ZERO;
row.rd_val_lo = M31::new(0x00FF);
row.rd_val_hi = M31::ZERO;

// Set bits for rs1 = 0x55 (bits 0,2,4,6 set)
for i in 0..31 {
row.rs1_bits[i] = if i < 8 && (0x55 & (1 << i)) != 0 { M31::ONE } else { M31::ZERO };
}
row.rs1_bits[31] = M31::ZERO;

// Set bits for rs2 = 0xAA (bits 1,3,5,7 set)
for i in 0..31 {
row.rs2_bits[i] = if i < 8 && (0xAA & (1 << i)) != 0 { M31::ONE } else { M31::ZERO };
}
row.rs2_bits[31] = M31::ZERO;

// Set bits for result = 0xFF (bits 0-7 set)
for i in 0..31 {
row.xor_bits[i] = if i < 8 { M31::ONE } else { M31::ZERO };
}
row.xor_bits[31] = M31::ZERO;

let c = ConstraintEvaluator::xor_constraint(&row);
assert_eq!(c, M31::ZERO, "XOR constraint should be satisfied for valid 31-bit values");
}

#[test]
fn test_or_constraint_31bit_values() {
// Test a valid 31-bit OR operation
let mut row = CpuTraceRow::default();
row.is_or = M31::ONE;

// OR: 0x12340000 | 0x00005678 = 0x12345678 (within 31 bits)
row.rs1_val_lo = M31::ZERO;
row.rs1_val_hi = M31::new(0x1234);
row.rs2_val_lo = M31::new(0x5678);
row.rs2_val_hi = M31::ZERO;
row.rd_val_lo = M31::new(0x5678);
row.rd_val_hi = M31::new(0x1234);

// Set bits for rs1 (high 16 bits)
for i in 0..31 {
let val = if i >= 16 && i < 31 { (0x1234 >> (i - 16)) & 1 } else { 0 };
row.rs1_bits[i] = if val != 0 { M31::ONE } else { M31::ZERO };
}
row.rs1_bits[31] = M31::ZERO;

// Set bits for rs2 (low 16 bits)
for i in 0..31 {
let val = if i < 16 { (0x5678 >> i) & 1 } else { 0 };
row.rs2_bits[i] = if val != 0 { M31::ONE } else { M31::ZERO };
}
row.rs2_bits[31] = M31::ZERO;

// Set bits for result (combined)
for i in 0..31 {
let val = if i < 16 {
(0x5678 >> i) & 1
} else if i < 31 {
(0x1234 >> (i - 16)) & 1
} else {
0
};
row.or_bits[i] = if val != 0 { M31::ONE } else { M31::ZERO };
}
row.or_bits[31] = M31::ZERO;

let c = ConstraintEvaluator::or_constraint(&row);
assert_eq!(c, M31::ZERO, "OR constraint should be satisfied for valid 31-bit values");
}
}