@@ -4,89 +4,86 @@ mod StackAir
4
4
5
5
# Flags for the first bits (op_bits[6], op_bits[5], op_bits[4])
6
6
7
- fn f_000(op_bits: vector[8 ]) -> scalar:
7
+ fn f_000(op_bits: vector[9 ]) -> scalar:
8
8
return !op_bits[6] & !op_bits[5] & !op_bits[4]
9
9
10
- fn f_001(op_bits: vector[8 ]) -> scalar:
10
+ fn f_001(op_bits: vector[9 ]) -> scalar:
11
11
return !op_bits[6] & !op_bits[5] & op_bits[4]
12
12
13
- fn f_010(op_bits: vector[8 ]) -> scalar:
13
+ fn f_010(op_bits: vector[9 ]) -> scalar:
14
14
return !op_bits[6] & op_bits[5] & !op_bits[4]
15
15
16
- fn f_011(op_bits: vector[8 ]) -> scalar:
16
+ fn f_011(op_bits: vector[9 ]) -> scalar:
17
17
return !op_bits[6] & op_bits[5] & op_bits[4]
18
18
19
19
# This flag is equal to f_100
20
- fn f_u32rc(op_bits: vector[8 ]) -> scalar:
20
+ fn f_u32rc(op_bits: vector[9 ]) -> scalar:
21
21
return op_bits[6] & !op_bits[5] & !op_bits[4]
22
22
23
- fn f_101(op_bits: vector[8]) -> scalar:
24
- return op_bits[6] & !op_bits[5] & op_bits[4]
25
-
26
23
27
24
# Flags for the four last bits (op_bits[3], op_bits[2], op_bits[1], op_bits[0])
28
25
29
- fn f_x0000(op_bits: vector[8 ]) -> scalar:
26
+ fn f_x0000(op_bits: vector[9 ]) -> scalar:
30
27
return !op_bits[3] & !op_bits[2] & !op_bits[1] & !op_bits[0]
31
28
32
- fn f_x0001(op_bits: vector[8 ]) -> scalar:
29
+ fn f_x0001(op_bits: vector[9 ]) -> scalar:
33
30
return !op_bits[3] & !op_bits[2] & !op_bits[1] & op_bits[0]
34
31
35
- fn f_x0010(op_bits: vector[8 ]) -> scalar:
32
+ fn f_x0010(op_bits: vector[9 ]) -> scalar:
36
33
return !op_bits[3] & !op_bits[2] & op_bits[1] & !op_bits[0]
37
34
38
- fn f_x0011(op_bits: vector[8 ]) -> scalar:
35
+ fn f_x0011(op_bits: vector[9 ]) -> scalar:
39
36
return !op_bits[3] & !op_bits[2] & op_bits[1] & op_bits[0]
40
37
41
- fn f_x0100(op_bits: vector[8 ]) -> scalar:
38
+ fn f_x0100(op_bits: vector[9 ]) -> scalar:
42
39
return !op_bits[3] & op_bits[2] & !op_bits[1] & !op_bits[0]
43
40
44
- fn f_x0101(op_bits: vector[8 ]) -> scalar:
41
+ fn f_x0101(op_bits: vector[9 ]) -> scalar:
45
42
return !op_bits[3] & op_bits[2] & !op_bits[1] & op_bits[0]
46
43
47
- fn f_x0110(op_bits: vector[8 ]) -> scalar:
44
+ fn f_x0110(op_bits: vector[9 ]) -> scalar:
48
45
return !op_bits[3] & op_bits[2] & op_bits[1] & !op_bits[0]
49
46
50
- fn f_x0111(op_bits: vector[8 ]) -> scalar:
47
+ fn f_x0111(op_bits: vector[9 ]) -> scalar:
51
48
return !op_bits[3] & op_bits[2] & op_bits[1] & op_bits[0]
52
49
53
- fn f_x1000(op_bits: vector[8 ]) -> scalar:
50
+ fn f_x1000(op_bits: vector[9 ]) -> scalar:
54
51
return op_bits[3] & !op_bits[2] & !op_bits[1] & !op_bits[0]
55
52
56
- fn f_x1001(op_bits: vector[8 ]) -> scalar:
53
+ fn f_x1001(op_bits: vector[9 ]) -> scalar:
57
54
return op_bits[3] & !op_bits[2] & !op_bits[1] & op_bits[0]
58
55
59
- fn f_x1010(op_bits: vector[8 ]) -> scalar:
56
+ fn f_x1010(op_bits: vector[9 ]) -> scalar:
60
57
return op_bits[3] & !op_bits[2] & op_bits[1] & !op_bits[0]
61
58
62
- fn f_x1011(op_bits: vector[8 ]) -> scalar:
59
+ fn f_x1011(op_bits: vector[9 ]) -> scalar:
63
60
return op_bits[3] & !op_bits[2] & op_bits[1] & op_bits[0]
64
61
65
- fn f_x1100(op_bits: vector[8 ]) -> scalar:
62
+ fn f_x1100(op_bits: vector[9 ]) -> scalar:
66
63
return op_bits[3] & op_bits[2] & !op_bits[1] & !op_bits[0]
67
64
68
- fn f_x1101(op_bits: vector[8 ]) -> scalar:
65
+ fn f_x1101(op_bits: vector[9 ]) -> scalar:
69
66
return op_bits[3] & op_bits[2] & !op_bits[1] & op_bits[0]
70
67
71
- fn f_x1110(op_bits: vector[8 ]) -> scalar:
68
+ fn f_x1110(op_bits: vector[9 ]) -> scalar:
72
69
return op_bits[3] & op_bits[2] & op_bits[1] & !op_bits[0]
73
70
74
- fn f_x1111(op_bits: vector[8 ]) -> scalar:
71
+ fn f_x1111(op_bits: vector[9 ]) -> scalar:
75
72
return op_bits[3] & op_bits[2] & op_bits[1] & op_bits[0]
76
73
77
74
78
75
# Composite flags
79
76
80
- fn f_shr(op_bits: vector[8 ]) -> scalar:
77
+ fn f_shr(op_bits: vector[9 ]) -> scalar:
81
78
return !op_bits[6] & op_bits[5] & op_bits[4] + f_u32split(op_bits) + f_push(op_bits)
82
79
83
80
# hasher[5] = op_helpers[3], where hahser[] are decoder columns, which are the same as helper[] -- columns from the stack
84
- fn f_shl(op_bits: vector[8 ], op_helpers: vector[6]) -> scalar:
81
+ fn f_shl(op_bits: vector[9 ], op_helpers: vector[6]) -> scalar:
85
82
let f_add3_mad = op_bits[6] & !op_bits[5] & !op_bits[4] & op_bits[3] & op_bits[2]
86
83
let f_split_loop = op_bits[6] & !op_bits[5] & op_bits[4] & op_bits[3] & op_bits[2]
87
84
return !op_bits[6] & op_bits[5] & !op_bits[4] + f_add3_mad + f_split_loop + f_repeat(op_bits) + f_end(op_bits) * op_helpers[3]
88
85
89
- fn f_ctrl(op_bits: vector[8 ]) -> scalar:
86
+ fn f_ctrl(op_bits: vector[9 ]) -> scalar:
90
87
# flag for SPAN, JOIN, SPLIT, LOOP
91
88
let f_sjsl = op_bits[6] & !op_bits[5] & op_bits[4] & op_bits[3]
92
89
@@ -96,7 +93,7 @@ fn f_ctrl(op_bits: vector[8]) -> scalar:
96
93
return f_sjsl + f_errh + f_call(op_bits) + f_syscall(op_bits)
97
94
98
95
99
- fn compute_op_flags(op_bits: vector[8 ]) -> vector[88]:
96
+ fn compute_op_flags(op_bits: vector[9 ]) -> vector[88]:
100
97
return [
101
98
# No stack shift operations
102
99
f_000(op_bits) & f_x0000(op_bits), # noop
@@ -178,25 +175,27 @@ fn compute_op_flags(op_bits: vector[8]) -> vector[88]:
178
175
f_u32rc(op_bits) & op_bits[3] & op_bits[2] & !op_bits[1], # u32add3
179
176
f_u32rc(op_bits) & op_bits[3] & op_bits[2] & op_bits[1], # u32madd
180
177
178
+
181
179
# High-degree operations
182
- f_101(op_bits) & !op_bits[3] & !op_bits[2] & !op_bits[1], # hperm
183
- f_101(op_bits) & !op_bits[3] & !op_bits[2] & op_bits[1], # mpverify
184
- f_101(op_bits) & !op_bits[3] & op_bits[2] & !op_bits[1], # pipe
185
- f_101(op_bits) & !op_bits[3] & op_bits[2] & op_bits[1], # mstream
186
- f_101(op_bits) & op_bits[3] & !op_bits[2] & !op_bits[1], # span
187
- f_101(op_bits) & op_bits[3] & !op_bits[2] & op_bits[1], # join
188
- f_101(op_bits) & op_bits[3] & op_bits[2] & !op_bits[1], # split
189
- f_101(op_bits) & op_bits[3] & op_bits[2] & op_bits[1], # loop
180
+ op_bits[7] & f_x0000(op_bits), # hperm
181
+ op_bits[7] & f_x0001(op_bits), # mpverify
182
+ op_bits[7] & f_x0010(op_bits), # pipe
183
+ op_bits[7] & f_x0011(op_bits), # mstream
184
+ op_bits[7] & f_x0100(op_bits), # split
185
+ op_bits[7] & f_x0101(op_bits), # loop
186
+ op_bits[7] & f_x0110(op_bits), # span
187
+ op_bits[7] & f_x0111(op_bits), # join
188
+
190
189
191
190
# Very high-degree operations
192
- op_bits[7 ] & !op_bits[4] & !op_bits[3] & !op_bits[2], # mrupdate
193
- op_bits[7 ] & !op_bits[4] & !op_bits[3] & op_bits[2], # push
194
- op_bits[7 ] & !op_bits[4] & op_bits[3] & !op_bits[2], # syscall
195
- op_bits[7 ] & !op_bits[4] & op_bits[3] & op_bits[2], # call
196
- op_bits[7 ] & op_bits[4] & !op_bits[3] & !op_bits[2], # end
197
- op_bits[7 ] & op_bits[4] & !op_bits[3] & op_bits[2], # repeat
198
- op_bits[7 ] & op_bits[4] & op_bits[3] & !op_bits[2], # respan
199
- op_bits[7 ] & op_bits[4] & op_bits[3] & op_bits[2], # halt
191
+ op_bits[8 ] & !op_bits[4] & !op_bits[3] & !op_bits[2], # mrupdate
192
+ op_bits[8 ] & !op_bits[4] & !op_bits[3] & op_bits[2], # push
193
+ op_bits[8 ] & !op_bits[4] & op_bits[3] & !op_bits[2], # syscall
194
+ op_bits[8 ] & !op_bits[4] & op_bits[3] & op_bits[2], # call
195
+ op_bits[8 ] & op_bits[4] & !op_bits[3] & !op_bits[2], # end
196
+ op_bits[8 ] & op_bits[4] & !op_bits[3] & op_bits[2], # repeat
197
+ op_bits[8 ] & op_bits[4] & op_bits[3] & !op_bits[2], # respan
198
+ op_bits[8 ] & op_bits[4] & op_bits[3] & op_bits[2], # halt
200
199
]
201
200
202
201
@@ -215,8 +214,8 @@ ev check_element_validity([op_helpers[6]]):
215
214
216
215
# Enforces that the last bit of the opcode (op_bits[0]) is always set to 0. This evaluator is used
217
216
# for u32 operations where the last bit of the opcode is not used in computation of the flag.
218
- ev b0_is_zero([op_bits[8 ]]):
219
- enf op_bits[6] & !op_bits[5] & op_bits[0] = 0
217
+ ev b0_is_zero([op_bits[9 ]]):
218
+ enf op_bits[6] & !op_bits[5] & !op_bits[4] & op_bits[0] = 0
220
219
221
220
# Enforces that the last two bits of the opcode (op_bits[0] and op_bits[1]) are always set to 0.
222
221
# This evaluator is used for very-high degree operations where the last two bits of the opcode are
@@ -225,13 +224,21 @@ ev b0_b1_is_zero([op_bits[8]]):
225
224
enf op_bits[6] & op_bits[5] & op_bits[0] = 0
226
225
enf op_bits[6] & op_bits[5] & op_bits[1] = 0
227
226
227
+ # Enforces that register extra0 is set to 1 when high-degree operations are executed.
228
+ ev extra0([op_bits[9]]):
229
+ op_bits[7] = 1 when op_bits[6] & !op_bits[5] & op_bits[4]
230
+
231
+ # Enforces that register extra1 is set to 1 when very high-degree operations are executed.
232
+ ev extra1([op_bits[9]]):
233
+ op_bits[8] = 1 when op_bits[6] & op_bits[5]
234
+
228
235
229
236
### Stack Air Constraints #########################################################################
230
237
231
238
# Enforces the constraints on the stack.
232
239
# TODO: add docs for columns
233
240
# stack_helpers consists of [bookkeeping[0], bookkeeping[1], h0]
234
- # op_bits consists of [op_bits[7], extra ]
241
+ # op_bits consists of [op_bits[7], extra0, extra1 ]
235
242
ev stack_constraints([stack_top[16], stack_helpers[3], op_bits[8], op_helpers[6], clk, fmp]):
236
243
let op_flags = compute_op_flags(op_bits)
237
244
@@ -607,57 +614,57 @@ ev clk([s[16], clk]):
607
614
608
615
# u32 operations
609
616
610
- ev u32add([s[16], op_bits[8 ], op_helpers[6]]):
617
+ ev u32add([s[16], op_bits[9 ], op_helpers[6]]):
611
618
enf s[0] + s[1] = 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
612
619
enf s[0]' = op_helpers[2]
613
620
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
614
621
enf s[i]' = s[i] for i in 2..16
615
622
enf b0_is_zero([op_bits])
616
623
617
- ev u32sub([s[16], op_bits[8 ], op_helpers[6]]):
624
+ ev u32sub([s[16], op_bits[9 ], op_helpers[6]]):
618
625
enf s[1] = s[0] + s[1]' + 2^32 * s[0]'
619
626
enf (s[0]')^2 - s[0]' = 0
620
627
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
621
628
enf s[i]' = s[i] for i in 2..16
622
629
enf b0_is_zero([op_bits])
623
630
624
- ev u32mul([s[16], op_bits[8 ], op_helpers[6]]):
631
+ ev u32mul([s[16], op_bits[9 ], op_helpers[6]]):
625
632
enf s[0] * s[1] = 2^48 * op_helpers[3] + 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
626
633
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
627
634
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
628
635
enf check_element_validity([op_helpers])
629
636
enf s[i]' = s[i] for i in 2..16
630
637
enf b0_is_zero([op_bits])
631
638
632
- ev u32div([s[16], op_bits[8 ], op_helpers[6]]):
639
+ ev u32div([s[16], op_bits[9 ], op_helpers[6]]):
633
640
enf s[1] = s[0] * s[1]' + s[0]'
634
641
enf s[1] - s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
635
642
enf s[0] - s[0]' - 1 = 2^16 * op_helpers[2] + op_helpers[3]
636
643
enf s[i]' = s[i] for i in 2..16
637
644
enf b0_is_zero([op_bits])
638
645
639
- ev u32split([s[16], op_bits[8 ], op_helpers[6]]):
646
+ ev u32split([s[16], op_bits[9 ], op_helpers[6]]):
640
647
enf s[0] = 2^48 * op_helpers[3] + 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
641
648
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
642
649
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
643
650
enf check_element_validity([op_helpers])
644
651
enf s[i + 1]' = s[i] for i in 1..15
645
652
enf b0_is_zero([op_bits])
646
653
647
- ev u32assert2([s[16], op_bits[8 ], op_helpers[6]]):
654
+ ev u32assert2([s[16], op_bits[9 ], op_helpers[6]]):
648
655
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
649
656
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
650
657
enf s[i]' = s[i] for i in 0..16
651
658
enf b0_is_zero([op_bits])
652
659
653
- ev u32add3([s[16], op_bits[8 ], op_helpers[6]]):
660
+ ev u32add3([s[16], op_bits[9 ], op_helpers[6]]):
654
661
enf s[0] + s[1] + s[2] = 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
655
662
enf s[0]' = op_helpers[2]
656
663
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
657
664
enf s[i]' = s[i + 1] for i in 2..15
658
665
enf b0_is_zero([op_bits])
659
666
660
- ev u32madd([s[16], op_bits[8 ], op_helpers[6]]):
667
+ ev u32madd([s[16], op_bits[9 ], op_helpers[6]]):
661
668
enf s[0] * s[1] + s[2] = 2^48 * op_helpers[3] + 2^32 * op_helpers[2] + 2^16 * op_helpers[1] + op_helpers[0]
662
669
enf s[1]' = 2^16 * op_helpers[1] + op_helpers[0]
663
670
enf s[0]' = 2^16 * op_helpers[3] + op_helpers[2]
@@ -668,28 +675,31 @@ ev u32madd([s[16], op_bits[8], op_helpers[6]]):
668
675
669
676
# High-degree operations
670
677
671
- # Bus constraint is implemented in a separate file
672
- ev hperm([s[16], op_bits[8], op_helpers[6]]):
678
+ ev hperm([s[16], op_bits[9], op_helpers[6]]):
673
679
enf s[i]' = s[i] for i in 12..16
674
- enf b0_is_zero([op_bits])
680
+ enf extra0(op_bits)
681
+ # Bus constraint is implemented in a separate file
675
682
676
- # Bus constraint is implemented in a separate file
677
- ev mpverify([s[16], op_bits[8], op_helpers[6]]):
683
+ ev mpverify([s[16], op_bits[9], op_helpers[6]]):
678
684
enf s[i]' = s[i] for i in 0..16
679
- enf b0_is_zero([op_bits])
685
+ enf extra0(op_bits)
686
+ # Bus constraint is implemented in a separate file
687
+
680
688
681
689
# TODO: add constraints
682
- ev pipe([s[16], op_bits[8 ], op_helpers[6]]):
690
+ ev pipe([s[16], op_bits[9 ], op_helpers[6]]):
683
691
684
692
685
- # Bus constraint is implemented in a separate file
686
- ev mstream([s[16], op_bits[8], op_helpers[6]]):
693
+ ev mstream([s[16], op_bits[9], op_helpers[6]]):
687
694
enf s[12]' = s[12] + 2
688
695
enf s[i]' = s[i] for i in 8..12
689
696
enf s[i]' = s[i] for i in 13..16
697
+ enf extra0(op_bits)
698
+ # Bus constraint is implemented in a separate file
699
+
690
700
691
701
# TODO: add constraints
692
- ev span([s[16], op_bits[8 ], op_helpers[6]])
702
+ ev span([s[16], op_bits[9 ], op_helpers[6]])
693
703
694
704
695
705
# TODO: add constraints
@@ -706,13 +716,18 @@ ev loop()
706
716
707
717
# Very high-degree operations
708
718
709
- # Bus constraint is implemented in a separate file
710
- ev mrupdate([s[16], op_bits[8], op_helpers[6]]):
719
+ ev mrupdate([s[16], op_bits[9], op_helpers[6]]):
711
720
enf s[i]' = s[i] for i in 4..16
712
721
enf b0_b1_is_zero([op_bits])
722
+ enf extra1(op_bits)
723
+ # Bus constraint is implemented in a separate file
724
+
713
725
714
726
ev push([s[16]]):
715
727
enf s[i + 1]' = s[i] for i in 0..15
728
+ enf b0_b1_is_zero([op_bits])
729
+ enf extra1(op_bits)
730
+
716
731
717
732
# TODO: add constraints
718
733
ev syscall():
0 commit comments