Skip to content

Commit 31d2715

Browse files
committed
refactor: update flags, add support for extra1
1 parent dadfc24 commit 31d2715

File tree

1 file changed

+82
-67
lines changed

1 file changed

+82
-67
lines changed

constraints/miden-vm/stack.air

+82-67
Original file line numberDiff line numberDiff line change
@@ -4,89 +4,86 @@ mod StackAir
44

55
# Flags for the first bits (op_bits[6], op_bits[5], op_bits[4])
66

7-
fn f_000(op_bits: vector[8]) -> scalar:
7+
fn f_000(op_bits: vector[9]) -> scalar:
88
return !op_bits[6] & !op_bits[5] & !op_bits[4]
99

10-
fn f_001(op_bits: vector[8]) -> scalar:
10+
fn f_001(op_bits: vector[9]) -> scalar:
1111
return !op_bits[6] & !op_bits[5] & op_bits[4]
1212

13-
fn f_010(op_bits: vector[8]) -> scalar:
13+
fn f_010(op_bits: vector[9]) -> scalar:
1414
return !op_bits[6] & op_bits[5] & !op_bits[4]
1515

16-
fn f_011(op_bits: vector[8]) -> scalar:
16+
fn f_011(op_bits: vector[9]) -> scalar:
1717
return !op_bits[6] & op_bits[5] & op_bits[4]
1818

1919
# This flag is equal to f_100
20-
fn f_u32rc(op_bits: vector[8]) -> scalar:
20+
fn f_u32rc(op_bits: vector[9]) -> scalar:
2121
return op_bits[6] & !op_bits[5] & !op_bits[4]
2222

23-
fn f_101(op_bits: vector[8]) -> scalar:
24-
return op_bits[6] & !op_bits[5] & op_bits[4]
25-
2623

2724
# Flags for the four last bits (op_bits[3], op_bits[2], op_bits[1], op_bits[0])
2825

29-
fn f_x0000(op_bits: vector[8]) -> scalar:
26+
fn f_x0000(op_bits: vector[9]) -> scalar:
3027
return !op_bits[3] & !op_bits[2] & !op_bits[1] & !op_bits[0]
3128

32-
fn f_x0001(op_bits: vector[8]) -> scalar:
29+
fn f_x0001(op_bits: vector[9]) -> scalar:
3330
return !op_bits[3] & !op_bits[2] & !op_bits[1] & op_bits[0]
3431

35-
fn f_x0010(op_bits: vector[8]) -> scalar:
32+
fn f_x0010(op_bits: vector[9]) -> scalar:
3633
return !op_bits[3] & !op_bits[2] & op_bits[1] & !op_bits[0]
3734

38-
fn f_x0011(op_bits: vector[8]) -> scalar:
35+
fn f_x0011(op_bits: vector[9]) -> scalar:
3936
return !op_bits[3] & !op_bits[2] & op_bits[1] & op_bits[0]
4037

41-
fn f_x0100(op_bits: vector[8]) -> scalar:
38+
fn f_x0100(op_bits: vector[9]) -> scalar:
4239
return !op_bits[3] & op_bits[2] & !op_bits[1] & !op_bits[0]
4340

44-
fn f_x0101(op_bits: vector[8]) -> scalar:
41+
fn f_x0101(op_bits: vector[9]) -> scalar:
4542
return !op_bits[3] & op_bits[2] & !op_bits[1] & op_bits[0]
4643

47-
fn f_x0110(op_bits: vector[8]) -> scalar:
44+
fn f_x0110(op_bits: vector[9]) -> scalar:
4845
return !op_bits[3] & op_bits[2] & op_bits[1] & !op_bits[0]
4946

50-
fn f_x0111(op_bits: vector[8]) -> scalar:
47+
fn f_x0111(op_bits: vector[9]) -> scalar:
5148
return !op_bits[3] & op_bits[2] & op_bits[1] & op_bits[0]
5249

53-
fn f_x1000(op_bits: vector[8]) -> scalar:
50+
fn f_x1000(op_bits: vector[9]) -> scalar:
5451
return op_bits[3] & !op_bits[2] & !op_bits[1] & !op_bits[0]
5552

56-
fn f_x1001(op_bits: vector[8]) -> scalar:
53+
fn f_x1001(op_bits: vector[9]) -> scalar:
5754
return op_bits[3] & !op_bits[2] & !op_bits[1] & op_bits[0]
5855

59-
fn f_x1010(op_bits: vector[8]) -> scalar:
56+
fn f_x1010(op_bits: vector[9]) -> scalar:
6057
return op_bits[3] & !op_bits[2] & op_bits[1] & !op_bits[0]
6158

62-
fn f_x1011(op_bits: vector[8]) -> scalar:
59+
fn f_x1011(op_bits: vector[9]) -> scalar:
6360
return op_bits[3] & !op_bits[2] & op_bits[1] & op_bits[0]
6461

65-
fn f_x1100(op_bits: vector[8]) -> scalar:
62+
fn f_x1100(op_bits: vector[9]) -> scalar:
6663
return op_bits[3] & op_bits[2] & !op_bits[1] & !op_bits[0]
6764

68-
fn f_x1101(op_bits: vector[8]) -> scalar:
65+
fn f_x1101(op_bits: vector[9]) -> scalar:
6966
return op_bits[3] & op_bits[2] & !op_bits[1] & op_bits[0]
7067

71-
fn f_x1110(op_bits: vector[8]) -> scalar:
68+
fn f_x1110(op_bits: vector[9]) -> scalar:
7269
return op_bits[3] & op_bits[2] & op_bits[1] & !op_bits[0]
7370

74-
fn f_x1111(op_bits: vector[8]) -> scalar:
71+
fn f_x1111(op_bits: vector[9]) -> scalar:
7572
return op_bits[3] & op_bits[2] & op_bits[1] & op_bits[0]
7673

7774

7875
# Composite flags
7976

80-
fn f_shr(op_bits: vector[8]) -> scalar:
77+
fn f_shr(op_bits: vector[9]) -> scalar:
8178
return !op_bits[6] & op_bits[5] & op_bits[4] + f_u32split(op_bits) + f_push(op_bits)
8279

8380
# hasher[5] = op_helpers[3], where hahser[] are decoder columns, which are the same as helper[] -- columns from the stack
84-
fn f_shl(op_bits: vector[8], op_helpers: vector[6]) -> scalar:
81+
fn f_shl(op_bits: vector[9], op_helpers: vector[6]) -> scalar:
8582
let f_add3_mad = op_bits[6] & !op_bits[5] & !op_bits[4] & op_bits[3] & op_bits[2]
8683
let f_split_loop = op_bits[6] & !op_bits[5] & op_bits[4] & op_bits[3] & op_bits[2]
8784
return !op_bits[6] & op_bits[5] & !op_bits[4] + f_add3_mad + f_split_loop + f_repeat(op_bits) + f_end(op_bits) * op_helpers[3]
8885

89-
fn f_ctrl(op_bits: vector[8]) -> scalar:
86+
fn f_ctrl(op_bits: vector[9]) -> scalar:
9087
# flag for SPAN, JOIN, SPLIT, LOOP
9188
let f_sjsl = op_bits[6] & !op_bits[5] & op_bits[4] & op_bits[3]
9289

@@ -96,7 +93,7 @@ fn f_ctrl(op_bits: vector[8]) -> scalar:
9693
return f_sjsl + f_errh + f_call(op_bits) + f_syscall(op_bits)
9794

9895

99-
fn compute_op_flags(op_bits: vector[8]) -> vector[88]:
96+
fn compute_op_flags(op_bits: vector[9]) -> vector[88]:
10097
return [
10198
# No stack shift operations
10299
f_000(op_bits) & f_x0000(op_bits), # noop
@@ -178,25 +175,27 @@ fn compute_op_flags(op_bits: vector[8]) -> vector[88]:
178175
f_u32rc(op_bits) & op_bits[3] & op_bits[2] & !op_bits[1], # u32add3
179176
f_u32rc(op_bits) & op_bits[3] & op_bits[2] & op_bits[1], # u32madd
180177

178+
181179
# High-degree operations
182-
f_101(op_bits) & !op_bits[3] & !op_bits[2] & !op_bits[1], # hperm
183-
f_101(op_bits) & !op_bits[3] & !op_bits[2] & op_bits[1], # mpverify
184-
f_101(op_bits) & !op_bits[3] & op_bits[2] & !op_bits[1], # pipe
185-
f_101(op_bits) & !op_bits[3] & op_bits[2] & op_bits[1], # mstream
186-
f_101(op_bits) & op_bits[3] & !op_bits[2] & !op_bits[1], # span
187-
f_101(op_bits) & op_bits[3] & !op_bits[2] & op_bits[1], # join
188-
f_101(op_bits) & op_bits[3] & op_bits[2] & !op_bits[1], # split
189-
f_101(op_bits) & op_bits[3] & op_bits[2] & op_bits[1], # loop
180+
op_bits[7] & f_x0000(op_bits), # hperm
181+
op_bits[7] & f_x0001(op_bits), # mpverify
182+
op_bits[7] & f_x0010(op_bits), # pipe
183+
op_bits[7] & f_x0011(op_bits), # mstream
184+
op_bits[7] & f_x0100(op_bits), # split
185+
op_bits[7] & f_x0101(op_bits), # loop
186+
op_bits[7] & f_x0110(op_bits), # span
187+
op_bits[7] & f_x0111(op_bits), # join
188+
190189

191190
# Very high-degree operations
192-
op_bits[7] & !op_bits[4] & !op_bits[3] & !op_bits[2], # mrupdate
193-
op_bits[7] & !op_bits[4] & !op_bits[3] & op_bits[2], # push
194-
op_bits[7] & !op_bits[4] & op_bits[3] & !op_bits[2], # syscall
195-
op_bits[7] & !op_bits[4] & op_bits[3] & op_bits[2], # call
196-
op_bits[7] & op_bits[4] & !op_bits[3] & !op_bits[2], # end
197-
op_bits[7] & op_bits[4] & !op_bits[3] & op_bits[2], # repeat
198-
op_bits[7] & op_bits[4] & op_bits[3] & !op_bits[2], # respan
199-
op_bits[7] & op_bits[4] & op_bits[3] & op_bits[2], # halt
191+
op_bits[8] & !op_bits[4] & !op_bits[3] & !op_bits[2], # mrupdate
192+
op_bits[8] & !op_bits[4] & !op_bits[3] & op_bits[2], # push
193+
op_bits[8] & !op_bits[4] & op_bits[3] & !op_bits[2], # syscall
194+
op_bits[8] & !op_bits[4] & op_bits[3] & op_bits[2], # call
195+
op_bits[8] & op_bits[4] & !op_bits[3] & !op_bits[2], # end
196+
op_bits[8] & op_bits[4] & !op_bits[3] & op_bits[2], # repeat
197+
op_bits[8] & op_bits[4] & op_bits[3] & !op_bits[2], # respan
198+
op_bits[8] & op_bits[4] & op_bits[3] & op_bits[2], # halt
200199
]
201200

202201

@@ -215,8 +214,8 @@ ev check_element_validity([op_helpers[6]]):
215214

216215
# Enforces that the last bit of the opcode (op_bits[0]) is always set to 0. This evaluator is used
217216
# for u32 operations where the last bit of the opcode is not used in computation of the flag.
218-
ev b0_is_zero([op_bits[8]]):
219-
enf op_bits[6] & !op_bits[5] & op_bits[0] = 0
217+
ev b0_is_zero([op_bits[9]]):
218+
enf op_bits[6] & !op_bits[5] & !op_bits[4] & op_bits[0] = 0
220219

221220
# Enforces that the last two bits of the opcode (op_bits[0] and op_bits[1]) are always set to 0.
222221
# This evaluator is used for very-high degree operations where the last two bits of the opcode are
@@ -225,13 +224,21 @@ ev b0_b1_is_zero([op_bits[8]]):
225224
enf op_bits[6] & op_bits[5] & op_bits[0] = 0
226225
enf op_bits[6] & op_bits[5] & op_bits[1] = 0
227226

227+
# Enforces that register extra0 is set to 1 when high-degree operations are executed.
228+
ev extra0([op_bits[9]]):
229+
op_bits[7] = 1 when op_bits[6] & !op_bits[5] & op_bits[4]
230+
231+
# Enforces that register extra1 is set to 1 when very high-degree operations are executed.
232+
ev extra1([op_bits[9]]):
233+
op_bits[8] = 1 when op_bits[6] & op_bits[5]
234+
228235

229236
### Stack Air Constraints #########################################################################
230237

231238
# Enforces the constraints on the stack.
232239
# TODO: add docs for columns
233240
# stack_helpers consists of [bookkeeping[0], bookkeeping[1], h0]
234-
# op_bits consists of [op_bits[7], extra]
241+
# op_bits consists of [op_bits[7], extra0, extra1]
235242
ev stack_constraints([stack_top[16], stack_helpers[3], op_bits[8], op_helpers[6], clk, fmp]):
236243
let op_flags = compute_op_flags(op_bits)
237244

@@ -607,57 +614,57 @@ ev clk([s[16], clk]):
607614

608615
# u32 operations
609616

610-
ev u32add([s[16], op_bits[8], op_helpers[6]]):
617+
ev u32add([s[16], op_bits[9], op_helpers[6]]):
611618
enf s[0] + s[1] = 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
612619
enf s[0]' = op_helpers[2]
613620
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
614621
enf s[i]' = s[i] for i in 2..16
615622
enf b0_is_zero([op_bits])
616623

617-
ev u32sub([s[16], op_bits[8], op_helpers[6]]):
624+
ev u32sub([s[16], op_bits[9], op_helpers[6]]):
618625
enf s[1] = s[0] + s[1]' + 2^32 * s[0]'
619626
enf (s[0]')^2 - s[0]' = 0
620627
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
621628
enf s[i]' = s[i] for i in 2..16
622629
enf b0_is_zero([op_bits])
623630

624-
ev u32mul([s[16], op_bits[8], op_helpers[6]]):
631+
ev u32mul([s[16], op_bits[9], op_helpers[6]]):
625632
enf s[0] * s[1] = 2^48 * op_helpers[3] + 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
626633
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
627634
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
628635
enf check_element_validity([op_helpers])
629636
enf s[i]' = s[i] for i in 2..16
630637
enf b0_is_zero([op_bits])
631638

632-
ev u32div([s[16], op_bits[8], op_helpers[6]]):
639+
ev u32div([s[16], op_bits[9], op_helpers[6]]):
633640
enf s[1] = s[0] * s[1]' + s[0]'
634641
enf s[1] - s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
635642
enf s[0] - s[0]' - 1 = 2^16 * op_helpers[2] + op_helpers[3]
636643
enf s[i]' = s[i] for i in 2..16
637644
enf b0_is_zero([op_bits])
638645

639-
ev u32split([s[16], op_bits[8], op_helpers[6]]):
646+
ev u32split([s[16], op_bits[9], op_helpers[6]]):
640647
enf s[0] = 2^48 * op_helpers[3] + 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
641648
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
642649
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
643650
enf check_element_validity([op_helpers])
644651
enf s[i + 1]' = s[i] for i in 1..15
645652
enf b0_is_zero([op_bits])
646653

647-
ev u32assert2([s[16], op_bits[8], op_helpers[6]]):
654+
ev u32assert2([s[16], op_bits[9], op_helpers[6]]):
648655
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
649656
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
650657
enf s[i]' = s[i] for i in 0..16
651658
enf b0_is_zero([op_bits])
652659

653-
ev u32add3([s[16], op_bits[8], op_helpers[6]]):
660+
ev u32add3([s[16], op_bits[9], op_helpers[6]]):
654661
enf s[0] + s[1] + s[2] = 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
655662
enf s[0]' = op_helpers[2]
656663
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
657664
enf s[i]' = s[i + 1] for i in 2..15
658665
enf b0_is_zero([op_bits])
659666

660-
ev u32madd([s[16], op_bits[8], op_helpers[6]]):
667+
ev u32madd([s[16], op_bits[9], op_helpers[6]]):
661668
enf s[0] * s[1] + s[2] = 2^48 * op_helpers[3] + 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
662669
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
663670
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
@@ -668,28 +675,31 @@ ev u32madd([s[16], op_bits[8], op_helpers[6]]):
668675

669676
# High-degree operations
670677

671-
# Bus constraint is implemented in a separate file
672-
ev hperm([s[16], op_bits[8], op_helpers[6]]):
678+
ev hperm([s[16], op_bits[9], op_helpers[6]]):
673679
enf s[i]' = s[i] for i in 12..16
674-
enf b0_is_zero([op_bits])
680+
enf extra0(op_bits)
681+
# Bus constraint is implemented in a separate file
675682

676-
# Bus constraint is implemented in a separate file
677-
ev mpverify([s[16], op_bits[8], op_helpers[6]]):
683+
ev mpverify([s[16], op_bits[9], op_helpers[6]]):
678684
enf s[i]' = s[i] for i in 0..16
679-
enf b0_is_zero([op_bits])
685+
enf extra0(op_bits)
686+
# Bus constraint is implemented in a separate file
687+
680688

681689
# TODO: add constraints
682-
ev pipe([s[16], op_bits[8], op_helpers[6]]):
690+
ev pipe([s[16], op_bits[9], op_helpers[6]]):
683691

684692

685-
# Bus constraint is implemented in a separate file
686-
ev mstream([s[16], op_bits[8], op_helpers[6]]):
693+
ev mstream([s[16], op_bits[9], op_helpers[6]]):
687694
enf s[12]' = s[12] + 2
688695
enf s[i]' = s[i] for i in 8..12
689696
enf s[i]' = s[i] for i in 13..16
697+
enf extra0(op_bits)
698+
# Bus constraint is implemented in a separate file
699+
690700

691701
# TODO: add constraints
692-
ev span([s[16], op_bits[8], op_helpers[6]])
702+
ev span([s[16], op_bits[9], op_helpers[6]])
693703

694704

695705
# TODO: add constraints
@@ -706,13 +716,18 @@ ev loop()
706716

707717
# Very high-degree operations
708718

709-
# Bus constraint is implemented in a separate file
710-
ev mrupdate([s[16], op_bits[8], op_helpers[6]]):
719+
ev mrupdate([s[16], op_bits[9], op_helpers[6]]):
711720
enf s[i]' = s[i] for i in 4..16
712721
enf b0_b1_is_zero([op_bits])
722+
enf extra1(op_bits)
723+
# Bus constraint is implemented in a separate file
724+
713725

714726
ev push([s[16]]):
715727
enf s[i + 1]' = s[i] for i in 0..15
728+
enf b0_b1_is_zero([op_bits])
729+
enf extra1(op_bits)
730+
716731

717732
# TODO: add constraints
718733
ev syscall():

0 commit comments

Comments
 (0)