Skip to content

Commit 295bdd6

Browse files
starknet_os: migrate test_revert
1 parent 308d88f commit 295bdd6

File tree

2 files changed

+204
-0
lines changed

2 files changed

+204
-0
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
pub(crate) mod implementation;
2+
#[cfg(test)]
3+
mod test;
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
use std::cell::RefCell;
2+
use std::collections::{HashMap, HashSet};
3+
use std::rc::Rc;
4+
use std::sync::LazyLock;
5+
6+
use apollo_starknet_os_program::OS_PROGRAM_BYTES;
7+
use cairo_vm::hint_processor::builtin_hint_processor::dict_manager::DictManager;
8+
use cairo_vm::hint_processor::builtin_hint_processor::hint_utils::insert_value_into_ap;
9+
use cairo_vm::types::layout_name::LayoutName;
10+
use cairo_vm::types::relocatable::MaybeRelocatable;
11+
use itertools::Itertools;
12+
use rstest::rstest;
13+
use starknet_api::core::CONTRACT_ADDRESS_DOMAIN_SIZE;
14+
use starknet_types_core::felt::Felt;
15+
16+
use crate::test_utils::cairo_dict::parse_contract_changes;
17+
use crate::test_utils::cairo_runner::{
18+
initialize_cairo_runner,
19+
run_cairo_0_entrypoint,
20+
EndpointArg,
21+
EntryPointRunnerConfig,
22+
ImplicitArg,
23+
ValueArg,
24+
};
25+
26+
const CHANGE_CONTRACT_ENTRY: Felt = CONTRACT_ADDRESS_DOMAIN_SIZE;
27+
static CHANGE_CLASS_ENTRY: LazyLock<Felt> = LazyLock::new(|| CHANGE_CONTRACT_ENTRY + Felt::ONE);
28+
29+
enum Operation {
30+
ChangeClass { class_hash: Felt },
31+
ChangeContract { contract_address: Felt },
32+
StorageWrite { address: Felt, value: Felt },
33+
}
34+
35+
impl Operation {
36+
fn encode(&self) -> [MaybeRelocatable; 2] {
37+
match self {
38+
Self::ChangeClass { class_hash } => [(*CHANGE_CLASS_ENTRY).into(), class_hash.into()],
39+
Self::ChangeContract { contract_address } => {
40+
[CHANGE_CONTRACT_ENTRY.into(), contract_address.into()]
41+
}
42+
Self::StorageWrite { address, value } => [address.into(), value.into()],
43+
}
44+
}
45+
}
46+
47+
#[rstest]
48+
#[case::noop(vec![])]
49+
#[case::write(vec![Operation::StorageWrite { address: Felt::from(7u8), value: Felt::from(7u8) }])]
50+
#[case::multiple(vec![
51+
Operation::ChangeContract { contract_address: Felt::from(6u8) },
52+
Operation::ChangeContract { contract_address: Felt::from(6u8) },
53+
Operation::ChangeClass { class_hash: Felt::from(4u8) },
54+
Operation::StorageWrite { address: Felt::from(7u8), value: Felt::from(7u8) },
55+
Operation::ChangeContract { contract_address: Felt::from(2u8) },
56+
Operation::StorageWrite { address: Felt::from(2u8), value: Felt::from(7u8) },
57+
Operation::StorageWrite { address: Felt::from(8u8), value: Felt::from(4u8) },
58+
])]
59+
fn test_revert(#[case] test_vector: Vec<Operation>) {
60+
let initial_contract_address = Felt::from(5u8);
61+
let initial_class_hash = Felt::ONE;
62+
let mut current_contract_address = initial_contract_address;
63+
let mut contract_addresses = HashSet::from([initial_contract_address]);
64+
let mut expected_storages: HashMap<Felt, HashMap<Felt, Felt>> = HashMap::new();
65+
let mut expected_class_hashes = HashMap::new();
66+
67+
for operation in test_vector.iter().rev() {
68+
match operation {
69+
Operation::ChangeClass { class_hash } => {
70+
expected_class_hashes.insert(current_contract_address, class_hash);
71+
}
72+
Operation::ChangeContract { contract_address } => {
73+
current_contract_address = *contract_address;
74+
contract_addresses.insert(*contract_address);
75+
}
76+
Operation::StorageWrite { address, value } => {
77+
expected_storages
78+
.entry(current_contract_address)
79+
.or_default()
80+
.insert(*address, *value);
81+
}
82+
}
83+
}
84+
85+
// Initialize the runner.
86+
// Pass no implicits, as the runner initialization only requires the implicit builtins; the
87+
// implicit state_changes arg is added later.
88+
let runner_config = EntryPointRunnerConfig {
89+
trace_enabled: false,
90+
verify_secure: true,
91+
layout: LayoutName::starknet,
92+
proof_mode: false,
93+
add_main_prefix_to_entrypoint: false,
94+
};
95+
let (mut runner, program, entrypoint) = initialize_cairo_runner(
96+
&runner_config,
97+
OS_PROGRAM_BYTES,
98+
"starkware.starknet.core.os.execution.revert.handle_revert",
99+
&[],
100+
HashMap::new(),
101+
)
102+
.unwrap();
103+
104+
// Create the implicit argument (contract state changes) for the runner.
105+
let state_changes: HashMap<MaybeRelocatable, MaybeRelocatable> = contract_addresses
106+
.iter()
107+
.sorted()
108+
.map(|address| {
109+
let state_entry: Vec<MaybeRelocatable> = vec![
110+
initial_class_hash.into(),
111+
runner.vm.add_memory_segment().into(), // storage_ptr
112+
Felt::ZERO.into(), // nonce
113+
];
114+
(address.into(), runner.vm.gen_arg(&state_entry).unwrap())
115+
})
116+
.collect();
117+
118+
// Add the state changes dict to the dict manager.
119+
let contract_state_changes = if let Ok(dict_manager) = runner.exec_scopes.get_dict_manager() {
120+
dict_manager.borrow_mut().new_dict(&mut runner.vm, state_changes).unwrap()
121+
} else {
122+
let mut dict_manager = DictManager::new();
123+
let base = dict_manager.new_dict(&mut runner.vm, state_changes).unwrap();
124+
runner.exec_scopes.insert_value("dict_manager", Rc::new(RefCell::new(dict_manager)));
125+
base
126+
};
127+
insert_value_into_ap(&mut runner.vm, contract_state_changes.clone()).unwrap();
128+
129+
// Construct the revert log.
130+
let revert_log: Vec<MaybeRelocatable> =
131+
Operation::ChangeContract { contract_address: CONTRACT_ADDRESS_DOMAIN_SIZE }
132+
.encode()
133+
.into_iter()
134+
.chain(test_vector.iter().flat_map(|operation| operation.encode().into_iter()))
135+
.collect();
136+
let revert_log_end =
137+
runner.vm.gen_arg(&revert_log).unwrap().add_int(&revert_log.len().into()).unwrap();
138+
139+
// Run the entrypoint.
140+
let explicit_args = vec![
141+
EndpointArg::Value(ValueArg::Single(initial_contract_address.into())),
142+
EndpointArg::Value(ValueArg::Single(revert_log_end)),
143+
];
144+
let implicit_args = vec![ImplicitArg::NonBuiltin(EndpointArg::Value(ValueArg::Single(
145+
contract_state_changes.clone(),
146+
)))];
147+
let state_reader = None;
148+
let expected_explicit_return_values = vec![];
149+
let (implicit_return_values, _explicit_return_values) = run_cairo_0_entrypoint(
150+
entrypoint,
151+
&explicit_args,
152+
&implicit_args,
153+
state_reader,
154+
&mut runner,
155+
&program,
156+
&runner_config,
157+
&expected_explicit_return_values,
158+
)
159+
.unwrap();
160+
161+
// Run the entrypoint and load the resulting contract changes dict.
162+
let [
163+
EndpointArg::Value(ValueArg::Single(MaybeRelocatable::RelocatableValue(
164+
contract_state_changes_end,
165+
))),
166+
] = implicit_return_values.as_slice()
167+
else {
168+
panic!("Unexpected implicit return values: {implicit_return_values:?}");
169+
};
170+
let actual_contract_changes = parse_contract_changes(
171+
&runner.vm,
172+
contract_state_changes.try_into().unwrap(),
173+
*contract_state_changes_end,
174+
);
175+
176+
// Verify the resulting contract changes dict.
177+
assert_eq!(
178+
HashSet::from_iter(actual_contract_changes.keys().map(|address| ***address)),
179+
contract_addresses
180+
);
181+
for (contract_address, contract_change) in actual_contract_changes.iter() {
182+
// Iterate over all storage changes for the contract address and verify that each change is
183+
// as expected.
184+
let expected_contract_storage =
185+
expected_storages.remove(contract_address).unwrap_or_default();
186+
assert_eq!(contract_change.storage_changes.len(), expected_contract_storage.len());
187+
for full_contract_change in contract_change.storage_changes.iter() {
188+
let expected_value = expected_contract_storage.get(&full_contract_change.key).unwrap();
189+
assert_eq!(full_contract_change.prev_value, Felt::ZERO);
190+
assert_eq!(full_contract_change.new_value, *expected_value);
191+
// TODO(Dori): If and when we get access to the final state of the hint processor,
192+
// verify that the current state in the execution helper for this contract address
193+
// and storage key is as expected.
194+
}
195+
196+
// Verify class hashes.
197+
let expected_class_hash =
198+
expected_class_hashes.get(contract_address).cloned().unwrap_or(&initial_class_hash);
199+
assert_eq!(contract_change.prev_class_hash.0, initial_class_hash);
200+
assert_eq!(contract_change.new_class_hash.0, *expected_class_hash);
201+
}
202+
}

0 commit comments

Comments
 (0)