Skip to content

Commit 923fbb5

Browse files
refactor: compute instruction folder
1 parent 15b5cfd commit 923fbb5

File tree

12 files changed

+3174
-0
lines changed

12 files changed

+3174
-0
lines changed
Lines changed: 373 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,373 @@
1+
"""Benchmark arithmetic instructions."""
2+
3+
import operator
4+
import random
5+
6+
import pytest
7+
from ethereum_test_forks import Fork
8+
from ethereum_test_tools import (
9+
Alloc,
10+
BenchmarkTestFiller,
11+
Bytecode,
12+
JumpLoopGenerator,
13+
Transaction,
14+
)
15+
from ethereum_test_vm import Opcodes as Op
16+
17+
from tests.benchmark.compute.helpers import DEFAULT_BINOP_ARGS, make_dup, neg
18+
19+
# Arithmetic instructions:
20+
# ADD, ADDMOD
21+
# SUB, SUBMOD
22+
# MUL, MULMOD
23+
# DIV, SDIV
24+
# MOD, SMOD
25+
# EXP
26+
# SIGNEXTEND
27+
28+
29+
@pytest.mark.parametrize(
30+
"opcode,opcode_args",
31+
[
32+
(
33+
Op.ADD,
34+
DEFAULT_BINOP_ARGS,
35+
),
36+
(
37+
Op.MUL,
38+
DEFAULT_BINOP_ARGS,
39+
),
40+
(
41+
# After every 2 SUB operations, values return to initial.
42+
Op.SUB,
43+
DEFAULT_BINOP_ARGS,
44+
),
45+
(
46+
# This has the cycle of 2:
47+
# v[0] = a // b
48+
# v[1] = a // v[0] = a // (a // b) = b
49+
# v[2] = a // b
50+
Op.DIV,
51+
(
52+
0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F,
53+
# We want the first divisor to be slightly bigger than 2**128:
54+
# this is the worst case for the division algorithm with
55+
# optimized paths for division by 1 and 2 words.
56+
0x100000000000000000000000000000033,
57+
),
58+
),
59+
(
60+
# This has the cycle of 2, see above.
61+
Op.DIV,
62+
(
63+
0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F,
64+
# We want the first divisor to be slightly bigger than 2**64:
65+
# this is the worst case for the division algorithm with an
66+
# optimized path for division by 1 word.
67+
0x10000000000000033,
68+
),
69+
),
70+
(
71+
# Same as DIV-0
72+
# But the numerator made positive, and the divisor made negative.
73+
Op.SDIV,
74+
(
75+
0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F,
76+
0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFCD,
77+
),
78+
),
79+
(
80+
# Same as DIV-1
81+
# But the numerator made positive, and the divisor made negative.
82+
Op.SDIV,
83+
(
84+
0x7FFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F,
85+
0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFFFFFFFFFFCD,
86+
),
87+
),
88+
(
89+
# Not suitable for MOD, as values quickly become zero.
90+
Op.MOD,
91+
DEFAULT_BINOP_ARGS,
92+
),
93+
(
94+
# Not suitable for SMOD, as values quickly become zero.
95+
Op.SMOD,
96+
DEFAULT_BINOP_ARGS,
97+
),
98+
(
99+
# This keeps the values unchanged
100+
# pow(2**256-1, 2**256-1, 2**256) == 2**256-1.
101+
Op.EXP,
102+
(
103+
0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF,
104+
0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF,
105+
),
106+
),
107+
(
108+
# Not great, as we always sign-extend the 4 bytes.
109+
Op.SIGNEXTEND,
110+
(
111+
3,
112+
0xFFDADADA, # Negative to have more work.
113+
),
114+
),
115+
],
116+
ids=lambda param: "" if isinstance(param, tuple) else param,
117+
)
118+
def test_arithmetic(
119+
benchmark_test: BenchmarkTestFiller,
120+
opcode: Op,
121+
opcode_args: tuple[int, int],
122+
) -> None:
123+
"""
124+
Benchmark binary instructions (takes two args, pushes one value).
125+
The execution starts with two initial values on the stack
126+
The stack is balanced by the DUP2 instruction.
127+
"""
128+
tx_data = b"".join(
129+
arg.to_bytes(32, byteorder="big") for arg in opcode_args
130+
)
131+
132+
setup = Op.CALLDATALOAD(0) + Op.CALLDATALOAD(32) + Op.DUP2 + Op.DUP2
133+
attack_block = Op.DUP2 + opcode
134+
cleanup = Op.POP + Op.POP + Op.DUP2 + Op.DUP2
135+
benchmark_test(
136+
code_generator=JumpLoopGenerator(
137+
setup=setup,
138+
attack_block=attack_block,
139+
cleanup=cleanup,
140+
tx_kwargs={"data": tx_data},
141+
),
142+
)
143+
144+
145+
@pytest.mark.parametrize("mod_bits", [255, 191, 127, 63])
146+
@pytest.mark.parametrize("op", [Op.MOD, Op.SMOD])
147+
def test_mod(
148+
benchmark_test: BenchmarkTestFiller,
149+
mod_bits: int,
150+
op: Op,
151+
) -> None:
152+
"""
153+
Benchmark MOD instructions.
154+
155+
The program consists of code segments evaluating the "MOD chain":
156+
mod[0] = calldataload(0)
157+
mod[1] = numerators[indexes[0]] % mod[0]
158+
mod[2] = numerators[indexes[1]] % mod[1] ...
159+
160+
The "numerators" is a pool of 15 constants pushed to the EVM stack at the
161+
program start.
162+
163+
The order of accessing the numerators is selected in a way the mod value
164+
remains in the range as long as possible.
165+
"""
166+
# For SMOD we negate both numerator and modulus. The underlying
167+
# computation is the same,
168+
# just the SMOD implementation will have to additionally handle the
169+
# sign bits.
170+
# The result stays negative.
171+
should_negate = op == Op.SMOD
172+
173+
num_numerators = 15
174+
numerator_bits = 256 if not should_negate else 255
175+
numerator_max = 2**numerator_bits - 1
176+
numerator_min = 2 ** (numerator_bits - 1)
177+
178+
# Pick the modulus min value so that it is _unlikely_ to drop to the lower
179+
# word count.
180+
assert mod_bits >= 63
181+
mod_min = 2 ** (mod_bits - 63)
182+
183+
# Select the random seed giving the longest found MOD chain. You can look
184+
# for a longer one by increasing the numerators_min_len. This will activate
185+
# the while loop below.
186+
match op, mod_bits:
187+
case Op.MOD, 255:
188+
seed = 20393
189+
numerators_min_len = 750
190+
case Op.MOD, 191:
191+
seed = 25979
192+
numerators_min_len = 770
193+
case Op.MOD, 127:
194+
seed = 17671
195+
numerators_min_len = 750
196+
case Op.MOD, 63:
197+
seed = 29181
198+
numerators_min_len = 730
199+
case Op.SMOD, 255:
200+
seed = 4015
201+
numerators_min_len = 750
202+
case Op.SMOD, 191:
203+
seed = 17355
204+
numerators_min_len = 750
205+
case Op.SMOD, 127:
206+
seed = 897
207+
numerators_min_len = 750
208+
case Op.SMOD, 63:
209+
seed = 7562
210+
numerators_min_len = 720
211+
case _:
212+
raise ValueError(f"{mod_bits}-bit {op} not supported.")
213+
214+
while True:
215+
rng = random.Random(seed)
216+
217+
# Create the list of random numerators.
218+
numerators = [
219+
rng.randint(numerator_min, numerator_max)
220+
for _ in range(num_numerators)
221+
]
222+
223+
# Create the random initial modulus.
224+
initial_mod = rng.randint(2 ** (mod_bits - 1), 2**mod_bits - 1)
225+
226+
# Evaluate the MOD chain and collect the order of accessing numerators.
227+
mod = initial_mod
228+
indexes = []
229+
while mod >= mod_min:
230+
# Compute results for each numerator.
231+
results = [n % mod for n in numerators]
232+
# And pick the best one.
233+
i = max(range(len(results)), key=results.__getitem__)
234+
mod = results[i]
235+
indexes.append(i)
236+
237+
# Disable if you want to find longer MOD chains.
238+
assert len(indexes) > numerators_min_len
239+
if len(indexes) > numerators_min_len:
240+
break
241+
seed += 1
242+
print(f"{seed=}")
243+
244+
# TODO: Don't use fixed PUSH32. Let Bytecode helpers to select optimal
245+
# push opcode.
246+
setup = sum((Op.PUSH32[n] for n in numerators), Bytecode())
247+
attack_block = (
248+
Op.CALLDATALOAD(0)
249+
+ sum(make_dup(len(numerators) - i) + op for i in indexes)
250+
+ Op.POP
251+
)
252+
253+
input_value = initial_mod if not should_negate else neg(initial_mod)
254+
benchmark_test(
255+
code_generator=JumpLoopGenerator(
256+
setup=setup,
257+
attack_block=attack_block,
258+
tx_kwargs={"data": input_value.to_bytes(32, byteorder="big")},
259+
),
260+
)
261+
262+
263+
@pytest.mark.parametrize("mod_bits", [255, 191, 127, 63])
264+
@pytest.mark.parametrize("op", [Op.ADDMOD, Op.MULMOD])
265+
def test_mod_arithmetic(
266+
benchmark_test: BenchmarkTestFiller,
267+
pre: Alloc,
268+
fork: Fork,
269+
mod_bits: int,
270+
op: Op,
271+
gas_benchmark_value: int,
272+
) -> None:
273+
"""
274+
Benchmark ADDMOD and MULMOD instructions.
275+
276+
The program consists of code segments evaluating the "op chain":
277+
mod[0] = calldataload(0)
278+
mod[1] = (fixed_arg op args[indexes[0]]) % mod[0]
279+
mod[2] = (fixed_arg op args[indexes[1]]) % mod[1]
280+
The "args" is a pool of 15 constants pushed to the EVM stack at the program
281+
start.
282+
The "fixed_arg" is the 0xFF...FF constant added to the EVM stack by PUSH32
283+
just before executing the "op".
284+
The order of accessing the numerators is selected in a way the mod value
285+
remains in the range as long as possible.
286+
"""
287+
fixed_arg = 2**256 - 1
288+
num_args = 15
289+
290+
max_code_size = fork.max_code_size()
291+
292+
# Pick the modulus min value so that it is _unlikely_ to drop to the lower
293+
# word count.
294+
assert mod_bits >= 63
295+
mod_min = 2 ** (mod_bits - 63)
296+
297+
# Select the random seed giving the longest found op chain. You can look
298+
# for a longer one by increasing the op_chain_len. This will activate the
299+
# while loop below.
300+
op_chain_len = 666
301+
match op, mod_bits:
302+
case Op.ADDMOD, 255:
303+
seed = 4
304+
case Op.ADDMOD, 191:
305+
seed = 2
306+
case Op.ADDMOD, 127:
307+
seed = 2
308+
case Op.ADDMOD, 63:
309+
seed = 64
310+
case Op.MULMOD, 255:
311+
seed = 5
312+
case Op.MULMOD, 191:
313+
seed = 389
314+
case Op.MULMOD, 127:
315+
seed = 5
316+
case Op.MULMOD, 63:
317+
# For this setup we were not able to find an op-chain longer than
318+
# 600.
319+
seed = 4193
320+
op_chain_len = 600
321+
case _:
322+
raise ValueError(f"{mod_bits}-bit {op} not supported.")
323+
324+
while True:
325+
rng = random.Random(seed)
326+
args = [rng.randint(2**255, 2**256 - 1) for _ in range(num_args)]
327+
initial_mod = rng.randint(2 ** (mod_bits - 1), 2**mod_bits - 1)
328+
329+
# Evaluate the op chain and collect the order of accessing numerators.
330+
op_fn = operator.add if op == Op.ADDMOD else operator.mul
331+
mod = initial_mod
332+
indexes: list[int] = []
333+
while mod >= mod_min and len(indexes) < op_chain_len:
334+
results = [op_fn(a, fixed_arg) % mod for a in args]
335+
# And pick the best one.
336+
i = max(range(len(results)), key=results.__getitem__)
337+
mod = results[i]
338+
indexes.append(i)
339+
340+
# Disable if you want to find longer op chains.
341+
assert len(indexes) == op_chain_len
342+
if len(indexes) == op_chain_len:
343+
break
344+
seed += 1
345+
print(f"{seed=}")
346+
347+
code_constant_pool = sum((Op.PUSH32[n] for n in args), Bytecode())
348+
code_segment = (
349+
Op.CALLDATALOAD(0)
350+
+ sum(
351+
make_dup(len(args) - i) + Op.PUSH32[fixed_arg] + op
352+
for i in indexes
353+
)
354+
+ Op.POP
355+
)
356+
# Construct the final code. Because of the usage of PUSH32 the code segment
357+
# is very long, so don't try to include multiple of these.
358+
code = (
359+
code_constant_pool
360+
+ Op.JUMPDEST
361+
+ code_segment
362+
+ Op.JUMP(len(code_constant_pool))
363+
)
364+
assert (max_code_size - len(code_segment)) < len(code) <= max_code_size
365+
366+
tx = Transaction(
367+
to=pre.deploy_contract(code=code),
368+
data=initial_mod.to_bytes(32, byteorder="big"),
369+
gas_limit=gas_benchmark_value,
370+
sender=pre.fund_eoa(),
371+
)
372+
373+
benchmark_test(tx=tx)

0 commit comments

Comments
 (0)