|
| 1 | +use std::cell::RefCell; |
| 2 | +use std::collections::{HashMap, HashSet}; |
| 3 | +use std::rc::Rc; |
| 4 | +use std::sync::LazyLock; |
| 5 | + |
| 6 | +use apollo_starknet_os_program::OS_PROGRAM_BYTES; |
| 7 | +use cairo_vm::hint_processor::builtin_hint_processor::dict_manager::DictManager; |
| 8 | +use cairo_vm::hint_processor::builtin_hint_processor::hint_utils::insert_value_into_ap; |
| 9 | +use cairo_vm::types::layout_name::LayoutName; |
| 10 | +use cairo_vm::types::relocatable::MaybeRelocatable; |
| 11 | +use itertools::Itertools; |
| 12 | +use rstest::rstest; |
| 13 | +use starknet_api::core::CONTRACT_ADDRESS_DOMAIN_SIZE; |
| 14 | +use starknet_types_core::felt::Felt; |
| 15 | + |
| 16 | +use crate::test_utils::cairo_dict::parse_contract_changes; |
| 17 | +use crate::test_utils::cairo_runner::{ |
| 18 | + initialize_cairo_runner, |
| 19 | + run_cairo_0_entrypoint, |
| 20 | + EndpointArg, |
| 21 | + EntryPointRunnerConfig, |
| 22 | + ImplicitArg, |
| 23 | + ValueArg, |
| 24 | +}; |
| 25 | + |
| 26 | +const CHANGE_CONTRACT_ENTRY: Felt = CONTRACT_ADDRESS_DOMAIN_SIZE; |
| 27 | +static CHANGE_CLASS_ENTRY: LazyLock<Felt> = LazyLock::new(|| CHANGE_CONTRACT_ENTRY + Felt::ONE); |
| 28 | + |
| 29 | +enum Operation { |
| 30 | + ChangeClass { class_hash: Felt }, |
| 31 | + ChangeContract { contract_address: Felt }, |
| 32 | + StorageWrite { address: Felt, value: Felt }, |
| 33 | +} |
| 34 | + |
| 35 | +impl Operation { |
| 36 | + fn encode(&self) -> [MaybeRelocatable; 2] { |
| 37 | + match self { |
| 38 | + Self::ChangeClass { class_hash } => { |
| 39 | + [Felt::from(*CHANGE_CLASS_ENTRY).into(), class_hash.into()] |
| 40 | + } |
| 41 | + Self::ChangeContract { contract_address } => { |
| 42 | + [CHANGE_CONTRACT_ENTRY.into(), contract_address.into()] |
| 43 | + } |
| 44 | + Self::StorageWrite { address, value } => [address.into(), value.into()], |
| 45 | + } |
| 46 | + } |
| 47 | +} |
| 48 | + |
| 49 | +#[rstest] |
| 50 | +#[case::noop(vec![])] |
| 51 | +#[case::write(vec![Operation::StorageWrite { address: Felt::from(7u8), value: Felt::from(7u8) }])] |
| 52 | +#[case::multiple(vec![ |
| 53 | + Operation::ChangeContract { contract_address: Felt::from(6u8) }, |
| 54 | + Operation::ChangeContract { contract_address: Felt::from(6u8) }, |
| 55 | + Operation::ChangeClass { class_hash: Felt::from(4u8) }, |
| 56 | + Operation::StorageWrite { address: Felt::from(7u8), value: Felt::from(7u8) }, |
| 57 | + Operation::ChangeContract { contract_address: Felt::from(2u8) }, |
| 58 | + Operation::StorageWrite { address: Felt::from(2u8), value: Felt::from(7u8) }, |
| 59 | + Operation::StorageWrite { address: Felt::from(8u8), value: Felt::from(4u8) }, |
| 60 | +])] |
| 61 | +fn test_revert(#[case] test_vector: Vec<Operation>) { |
| 62 | + let initial_contract_address = Felt::from(5u8); |
| 63 | + let initial_class_hash = Felt::ONE; |
| 64 | + let mut current_contract_address = initial_contract_address; |
| 65 | + let mut contract_addresses = HashSet::from([initial_contract_address]); |
| 66 | + let mut expected_storages: HashMap<Felt, HashMap<Felt, Felt>> = HashMap::new(); |
| 67 | + let mut expected_class_hashes = HashMap::new(); |
| 68 | + |
| 69 | + for operation in test_vector.iter().rev() { |
| 70 | + match operation { |
| 71 | + Operation::ChangeClass { class_hash } => { |
| 72 | + expected_class_hashes.insert(current_contract_address, class_hash); |
| 73 | + } |
| 74 | + Operation::ChangeContract { contract_address } => { |
| 75 | + current_contract_address = *contract_address; |
| 76 | + contract_addresses.insert(*contract_address); |
| 77 | + } |
| 78 | + Operation::StorageWrite { address, value } => { |
| 79 | + expected_storages |
| 80 | + .entry(current_contract_address) |
| 81 | + .and_modify(|map| { |
| 82 | + map.insert(*address, *value); |
| 83 | + }) |
| 84 | + .or_insert_with(|| HashMap::from([(*address, *value)])); |
| 85 | + } |
| 86 | + } |
| 87 | + } |
| 88 | + |
| 89 | + // Initialize the runner. |
| 90 | + // Pass no implicits, as the runner initialization only requires the implicit builtins; the |
| 91 | + // implicit state_changes arg is added later. |
| 92 | + let runner_config = EntryPointRunnerConfig { |
| 93 | + trace_enabled: false, |
| 94 | + verify_secure: true, |
| 95 | + layout: LayoutName::starknet, |
| 96 | + proof_mode: false, |
| 97 | + add_main_prefix_to_entrypoint: false, |
| 98 | + }; |
| 99 | + let (mut runner, program, entrypoint) = initialize_cairo_runner( |
| 100 | + &runner_config, |
| 101 | + OS_PROGRAM_BYTES, |
| 102 | + "starkware.starknet.core.os.execution.revert.handle_revert", |
| 103 | + &vec![], |
| 104 | + HashMap::new(), |
| 105 | + ) |
| 106 | + .unwrap(); |
| 107 | + |
| 108 | + // Create the implicit argument (contract state changes) for the runner. |
| 109 | + let state_changes: HashMap<MaybeRelocatable, MaybeRelocatable> = contract_addresses |
| 110 | + .iter() |
| 111 | + .sorted() |
| 112 | + .map(|address| { |
| 113 | + let state_entry: Vec<MaybeRelocatable> = vec![ |
| 114 | + initial_class_hash.into(), |
| 115 | + runner.vm.add_memory_segment().into(), // storage_ptr |
| 116 | + Felt::ZERO.into(), // nonce |
| 117 | + ]; |
| 118 | + (address.into(), runner.vm.gen_arg(&state_entry).unwrap()) |
| 119 | + }) |
| 120 | + .collect(); |
| 121 | + |
| 122 | + // Add the state changes dict to the dict manager. |
| 123 | + let contract_state_changes = if let Ok(dict_manager) = runner.exec_scopes.get_dict_manager() { |
| 124 | + dict_manager.borrow_mut().new_dict(&mut runner.vm, state_changes).unwrap() |
| 125 | + } else { |
| 126 | + let mut dict_manager = DictManager::new(); |
| 127 | + let base = dict_manager.new_dict(&mut runner.vm, state_changes).unwrap(); |
| 128 | + runner.exec_scopes.insert_value("dict_manager", Rc::new(RefCell::new(dict_manager))); |
| 129 | + base |
| 130 | + }; |
| 131 | + insert_value_into_ap(&mut runner.vm, contract_state_changes.clone()).unwrap(); |
| 132 | + |
| 133 | + // Construct the revert log. |
| 134 | + let revert_log: Vec<MaybeRelocatable> = |
| 135 | + Operation::ChangeContract { contract_address: CONTRACT_ADDRESS_DOMAIN_SIZE } |
| 136 | + .encode() |
| 137 | + .into_iter() |
| 138 | + .chain(test_vector.iter().flat_map(|operation| operation.encode().into_iter())) |
| 139 | + .collect(); |
| 140 | + let revert_log_end = |
| 141 | + runner.vm.gen_arg(&revert_log).unwrap().add_int(&revert_log.len().into()).unwrap(); |
| 142 | + |
| 143 | + // Run the entrypoint. |
| 144 | + let explicit_args = vec![ |
| 145 | + EndpointArg::Value(ValueArg::Single(initial_contract_address.into())), |
| 146 | + EndpointArg::Value(ValueArg::Single(revert_log_end.into())), |
| 147 | + ]; |
| 148 | + let implicit_args = vec![ImplicitArg::NonBuiltin(EndpointArg::Value(ValueArg::Single( |
| 149 | + contract_state_changes.clone().into(), |
| 150 | + )))]; |
| 151 | + let state_reader = None; |
| 152 | + let expected_explicit_return_values = vec![]; |
| 153 | + let (implicit_return_values, _explicit_return_values) = run_cairo_0_entrypoint( |
| 154 | + entrypoint, |
| 155 | + &explicit_args, |
| 156 | + &implicit_args, |
| 157 | + state_reader, |
| 158 | + &mut runner, |
| 159 | + &program, |
| 160 | + &runner_config, |
| 161 | + &expected_explicit_return_values, |
| 162 | + ) |
| 163 | + .unwrap(); |
| 164 | + |
| 165 | + // Run the entrypoint and load the resulting contract changes dict. |
| 166 | + let [ |
| 167 | + EndpointArg::Value(ValueArg::Single(MaybeRelocatable::RelocatableValue( |
| 168 | + contract_state_changes_end, |
| 169 | + ))), |
| 170 | + ] = implicit_return_values.as_slice() |
| 171 | + else { |
| 172 | + panic!("Unexpected implicit return values: {implicit_return_values:?}"); |
| 173 | + }; |
| 174 | + let actual_contract_changes = parse_contract_changes( |
| 175 | + &runner.vm, |
| 176 | + contract_state_changes.try_into().unwrap(), |
| 177 | + *contract_state_changes_end, |
| 178 | + ); |
| 179 | + |
| 180 | + // Verify the resulting contract changes dict. |
| 181 | + assert_eq!( |
| 182 | + HashSet::from_iter(actual_contract_changes.keys().map(|address| ***address)), |
| 183 | + contract_addresses |
| 184 | + ); |
| 185 | + for (contract_address, contract_change) in actual_contract_changes.iter() { |
| 186 | + // Iterate over all storage changes for the contract address and verify that each change is |
| 187 | + // as expected. |
| 188 | + let expected_contract_storage = |
| 189 | + expected_storages.remove(contract_address).unwrap_or_default(); |
| 190 | + assert_eq!(contract_change.storage_changes.len(), expected_contract_storage.len()); |
| 191 | + for full_contract_change in contract_change.storage_changes.iter() { |
| 192 | + let expected_value = expected_contract_storage.get(&full_contract_change.key).unwrap(); |
| 193 | + assert_eq!(full_contract_change.prev_value, Felt::ZERO); |
| 194 | + assert_eq!(full_contract_change.new_value, *expected_value); |
| 195 | + // TODO(Dori): If and when we get access to the final state of the hint processor, |
| 196 | + // verify that the current state in the execution helper for this contract address |
| 197 | + // and storage key is as expected. |
| 198 | + } |
| 199 | + |
| 200 | + // Verify class hashes. |
| 201 | + let expected_class_hash = |
| 202 | + expected_class_hashes.get(contract_address).cloned().unwrap_or(&initial_class_hash); |
| 203 | + assert_eq!(contract_change.prev_class_hash.0, initial_class_hash); |
| 204 | + assert_eq!(contract_change.new_class_hash.0, *expected_class_hash); |
| 205 | + } |
| 206 | +} |
0 commit comments