Skip to content

Commit ad05ca5

Browse files
authored
migrate jump opcode (#1024)
To close #1012 based on #1008 ### circuit stats ```diff +---------------+---------------+---------+-------+-----------+--------+------------+---------------------+ | opcode_name | num_instances | lookups | reads | witnesses | writes | 0_expr_deg | 0_expr_sumcheck_deg | +---------------+---------------+---------+-------+-----------+--------+------------+---------------------+ - | JAL | 0 | 5 | 2 | 11 | 2 | [1: 2] | [] | + | JAL | 0 | 8 | 2 | 13 | 2 | [1: 2] | [] | - | JALR | 0 | 9 | 3 | 22 | 3 | [1: 5] | [2: 2] | + | JALR | 0 | 9 | 3 | 22 | 3 | [1: 4] | [2: 4] | +---------------+---------------+---------+-------+-----------+--------+------------+---------------------+ ```
1 parent 1c48793 commit ad05ca5

File tree

9 files changed

+456
-32
lines changed

9 files changed

+456
-32
lines changed

ceno_zkvm/src/instructions/riscv/insn_base.rs

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use ff_ext::{ExtensionField, FieldInto, SmallField};
33
use itertools::Itertools;
44
use p3::field::{Field, FieldAlgebra};
55

6-
use super::constants::{PC_STEP_SIZE, UINT_LIMBS, UInt};
6+
use super::constants::{BIT_WIDTH, PC_STEP_SIZE, UINT_LIMBS, UInt};
77
use crate::{
88
chip_handler::{
99
AddressExpr, GlobalStateRegisterMachineChipOperations, MemoryChipOperations, MemoryExpr,
@@ -368,6 +368,7 @@ impl WriteMEM {
368368
pub struct MemAddr<E: ExtensionField> {
369369
addr: UInt<E>,
370370
low_bits: Vec<WitIn>,
371+
max_bits: usize,
371372
}
372373

373374
impl<E: ExtensionField> MemAddr<E> {
@@ -393,6 +394,17 @@ impl<E: ExtensionField> MemAddr<E> {
393394
self.addr.address_expr()
394395
}
395396

397+
pub fn uint_unaligned(&self) -> UInt<E> {
398+
UInt::from_exprs_unchecked(self.addr.expr())
399+
}
400+
401+
pub fn uint_align2(&self) -> UInt<E> {
402+
UInt::from_exprs_unchecked(vec![
403+
self.addr.limbs[0].expr() - &self.low_bit_exprs()[0],
404+
self.addr.limbs[1].expr(),
405+
])
406+
}
407+
396408
/// Represent the address aligned to 2 bytes.
397409
pub fn expr_align2(&self) -> AddressExpr<E> {
398410
self.addr.address_expr() - &self.low_bit_exprs()[0]
@@ -404,6 +416,14 @@ impl<E: ExtensionField> MemAddr<E> {
404416
self.addr.address_expr() - &low_bits[1] * 2 - &low_bits[0]
405417
}
406418

419+
pub fn uint_align4(&self) -> UInt<E> {
420+
let low_bits = self.low_bit_exprs();
421+
UInt::from_exprs_unchecked(vec![
422+
self.addr.limbs[0].expr() - &low_bits[1] * 2 - &low_bits[0],
423+
self.addr.limbs[1].expr(),
424+
])
425+
}
426+
407427
/// Expressions of the low bits of the address, LSB-first: [bit_0, bit_1].
408428
pub fn low_bit_exprs(&self) -> Vec<Expression<E>> {
409429
iter::repeat_n(Expression::ZERO, self.n_zeros())
@@ -412,6 +432,14 @@ impl<E: ExtensionField> MemAddr<E> {
412432
}
413433

414434
fn construct(cb: &mut CircuitBuilder<E>, n_zeros: usize) -> Result<Self, ZKVMError> {
435+
Self::construct_with_max_bits(cb, n_zeros, BIT_WIDTH)
436+
}
437+
438+
pub fn construct_with_max_bits(
439+
cb: &mut CircuitBuilder<E>,
440+
n_zeros: usize,
441+
max_bits: usize,
442+
) -> Result<Self, ZKVMError> {
415443
assert!(n_zeros <= Self::N_LOW_BITS);
416444

417445
// The address as two u16 limbs.
@@ -442,11 +470,19 @@ impl<E: ExtensionField> MemAddr<E> {
442470
cb.assert_ux::<_, _, 14>(|| "mid_u14", mid_u14)?;
443471

444472
// Range check the high limb.
445-
for high_u16 in limbs.iter().skip(1) {
446-
cb.assert_ux::<_, _, 16>(|| "high_u16", high_u16.clone())?;
473+
for (i, high_limb) in limbs.iter().enumerate().skip(1) {
474+
cb.assert_ux_v2(
475+
|| "high_limb",
476+
high_limb.clone(),
477+
(max_bits - i * 16).min(16),
478+
)?;
447479
}
448480

449-
Ok(MemAddr { addr, low_bits })
481+
Ok(MemAddr {
482+
addr,
483+
low_bits,
484+
max_bits,
485+
})
450486
}
451487

452488
pub fn assign_instance(
@@ -470,7 +506,8 @@ impl<E: ExtensionField> MemAddr<E> {
470506
// Range check the high limb.
471507
for i in 1..UINT_LIMBS {
472508
let high_u16 = (addr >> (i * 16)) & 0xffff;
473-
lkm.assert_ux::<16>(high_u16 as u64);
509+
println!("assignment max bit {}", (self.max_bits - i * 16).min(16));
510+
lkm.assert_ux_v2(high_u16 as u64, (self.max_bits - i * 16).min(16));
474511
}
475512

476513
Ok(())
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,22 @@
1+
#[cfg(not(feature = "u16limb_circuit"))]
12
mod jal;
3+
#[cfg(feature = "u16limb_circuit")]
4+
mod jal_v2;
5+
6+
#[cfg(not(feature = "u16limb_circuit"))]
27
mod jalr;
8+
#[cfg(feature = "u16limb_circuit")]
9+
mod jalr_v2;
310

11+
#[cfg(not(feature = "u16limb_circuit"))]
412
pub use jal::JalInstruction;
13+
#[cfg(feature = "u16limb_circuit")]
14+
pub use jal_v2::JalInstruction;
15+
16+
#[cfg(not(feature = "u16limb_circuit"))]
517
pub use jalr::JalrInstruction;
18+
#[cfg(feature = "u16limb_circuit")]
19+
pub use jalr_v2::JalrInstruction;
620

721
#[cfg(test)]
822
mod test;
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
use std::marker::PhantomData;
2+
3+
use ff_ext::ExtensionField;
4+
5+
use crate::{
6+
circuit_builder::CircuitBuilder,
7+
error::ZKVMError,
8+
instructions::{
9+
Instruction,
10+
riscv::{
11+
constants::{PC_BITS, UINT_BYTE_LIMBS, UInt8},
12+
j_insn::JInstructionConfig,
13+
},
14+
},
15+
structs::ProgramParams,
16+
utils::split_to_u8,
17+
witness::LkMultiplicity,
18+
};
19+
use ceno_emul::{InsnKind, PC_STEP_SIZE};
20+
use gkr_iop::tables::{LookupTable, ops::XorTable};
21+
use multilinear_extensions::{Expression, ToExpr};
22+
use p3::field::FieldAlgebra;
23+
24+
pub struct JalConfig<E: ExtensionField> {
25+
pub j_insn: JInstructionConfig<E>,
26+
pub rd_written: UInt8<E>,
27+
}
28+
29+
pub struct JalInstruction<E>(PhantomData<E>);
30+
31+
/// JAL instruction circuit
32+
///
33+
/// Note: does not validate that next_pc is aligned by 4-byte increments, which
34+
/// should be verified by lookup argument of the next execution step against
35+
/// the program table
36+
///
37+
/// Assumption: values for valid initial program counter must lie between
38+
/// 2^20 and 2^32 - 2^20 + 2 inclusive, probably enforced by the static
39+
/// program lookup table. If this assumption does not hold, then resulting
40+
/// value for next_pc may not correctly wrap mod 2^32 because of the use
41+
/// of native WitIn values for address space arithmetic.
42+
impl<E: ExtensionField> Instruction<E> for JalInstruction<E> {
43+
type InstructionConfig = JalConfig<E>;
44+
45+
fn name() -> String {
46+
format!("{:?}", InsnKind::JAL)
47+
}
48+
49+
fn construct_circuit(
50+
circuit_builder: &mut CircuitBuilder<E>,
51+
_params: &ProgramParams,
52+
) -> Result<JalConfig<E>, ZKVMError> {
53+
let rd_written = UInt8::new(|| "rd_written", circuit_builder)?;
54+
let rd_exprs = rd_written.expr();
55+
56+
let j_insn = JInstructionConfig::construct_circuit(
57+
circuit_builder,
58+
InsnKind::JAL,
59+
rd_written.register_expr(),
60+
)?;
61+
62+
// constrain rd_exprs [PC_BITS .. u32::BITS] are all 0 via xor
63+
let last_limb_bits = PC_BITS - UInt8::<E>::LIMB_BITS * (UInt8::<E>::NUM_LIMBS - 1);
64+
let additional_bits =
65+
(last_limb_bits..UInt8::<E>::LIMB_BITS).fold(0, |acc, x| acc + (1 << x));
66+
let additional_bits = E::BaseField::from_canonical_u32(additional_bits);
67+
circuit_builder.logic_u8(
68+
LookupTable::Xor,
69+
rd_exprs[3].expr(),
70+
additional_bits.expr(),
71+
rd_exprs[3].expr() + additional_bits.expr(),
72+
)?;
73+
74+
circuit_builder.require_equal(
75+
|| "jal rd_written",
76+
rd_exprs
77+
.iter()
78+
.enumerate()
79+
.fold(Expression::ZERO, |acc, (i, val)| {
80+
acc + val.expr()
81+
* E::BaseField::from_canonical_u32(1 << (i * UInt8::<E>::LIMB_BITS)).expr()
82+
}),
83+
j_insn.vm_state.pc.expr() + PC_STEP_SIZE,
84+
)?;
85+
86+
Ok(JalConfig { j_insn, rd_written })
87+
}
88+
89+
fn assign_instance(
90+
config: &Self::InstructionConfig,
91+
instance: &mut [E::BaseField],
92+
lk_multiplicity: &mut LkMultiplicity,
93+
step: &ceno_emul::StepRecord,
94+
) -> Result<(), ZKVMError> {
95+
config
96+
.j_insn
97+
.assign_instance(instance, lk_multiplicity, step)?;
98+
99+
let rd_written = split_to_u8(step.rd().unwrap().value.after);
100+
config.rd_written.assign_limbs(instance, &rd_written);
101+
for val in &rd_written {
102+
lk_multiplicity.assert_ux::<8>(*val as u64);
103+
}
104+
105+
// constrain pc msb limb range via xor
106+
let last_limb_bits = PC_BITS - UInt8::<E>::LIMB_BITS * (UINT_BYTE_LIMBS - 1);
107+
let additional_bits =
108+
(last_limb_bits..UInt8::<E>::LIMB_BITS).fold(0, |acc, x| acc + (1 << x));
109+
lk_multiplicity.logic_u8::<XorTable>(rd_written[3] as u64, additional_bits as u64);
110+
111+
Ok(())
112+
}
113+
}

ceno_zkvm/src/instructions/riscv/jump/jalr.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,6 @@ impl<E: ExtensionField> Instruction<E> for JalrInstruction<E> {
5353
circuit_builder,
5454
InsnKind::JALR,
5555
imm.expr(),
56-
#[cfg(feature = "u16limb_circuit")]
57-
0.into(),
5856
rs1_read.register_expr(),
5957
rd_written.register_expr(),
6058
true,

0 commit comments

Comments
 (0)