|
1 | 1 | use std::collections::HashMap; |
2 | 2 |
|
| 3 | +use apollo_starknet_os_program::OS_PROGRAM; |
3 | 4 | use cairo_vm::hint_processor::builtin_hint_processor::dict_hint_utils::DICT_ACCESS_SIZE; |
4 | 5 | use cairo_vm::types::relocatable::{MaybeRelocatable, Relocatable}; |
5 | 6 | use cairo_vm::vm::vm_core::VirtualMachine; |
| 7 | +use itertools::Itertools; |
| 8 | +use starknet_api::core::{ClassHash, ContractAddress, Nonce, PatriciaKey}; |
| 9 | +use starknet_api::state::StorageKey; |
6 | 10 | use starknet_types_core::felt::Felt; |
7 | 11 |
|
| 12 | +use crate::hints::vars::CairoStruct; |
| 13 | +use crate::io::os_output_types::{FullContractChanges, FullContractStorageUpdate}; |
| 14 | +use crate::vm_utils::get_address_of_nested_fields_from_base_address; |
| 15 | + |
8 | 16 | /// Creates a squashed dict from previous and new values, and stores it in a new memory segment. |
9 | 17 | pub fn allocate_squashed_cairo_dict( |
10 | 18 | prev_values: &HashMap<Felt, MaybeRelocatable>, |
@@ -43,3 +51,129 @@ pub fn parse_squashed_cairo_dict(squashed_dict: &[Felt]) -> HashMap<Felt, Felt> |
43 | 51 | .map(|chunk| (chunk[key_offset], chunk[new_val_offset])) |
44 | 52 | .collect() |
45 | 53 | } |
| 54 | + |
| 55 | +/// Parses a cairo dictionary from VM memory into a squashed dictionary. |
| 56 | +/// Each entry is a tuple of the form (key, (prev_value, new_value)). |
| 57 | +pub fn squash_dict( |
| 58 | + vm: &VirtualMachine, |
| 59 | + dict_start: Relocatable, |
| 60 | + dict_end: Relocatable, |
| 61 | +) -> Vec<(Felt, (Felt, Felt))> { |
| 62 | + let mut prev_vals = HashMap::new(); |
| 63 | + let mut new_vals = HashMap::new(); |
| 64 | + let flat_dict: Vec<MaybeRelocatable> = vm |
| 65 | + .segments |
| 66 | + .memory |
| 67 | + .get_range(dict_start, (dict_end - dict_start).unwrap()) |
| 68 | + .into_iter() |
| 69 | + .map(|item| item.unwrap().into_owned()) |
| 70 | + .collect(); |
| 71 | + for chunk in flat_dict.chunks_exact(DICT_ACCESS_SIZE) { |
| 72 | + let (key, prev, new) = ( |
| 73 | + chunk.first().unwrap().get_int().unwrap(), |
| 74 | + chunk.get(1).unwrap().get_int().unwrap(), |
| 75 | + chunk.get(2).unwrap().get_int().unwrap(), |
| 76 | + ); |
| 77 | + if !prev_vals.contains_key(&prev) { |
| 78 | + prev_vals.insert(key, prev); |
| 79 | + } else { |
| 80 | + assert_eq!(new_vals.get(&key).unwrap(), &prev); |
| 81 | + } |
| 82 | + new_vals.insert(key, new); |
| 83 | + } |
| 84 | + prev_vals.into_iter().map(|(key, prev)| (key, (prev, *new_vals.get(&key).unwrap()))).collect() |
| 85 | +} |
| 86 | + |
| 87 | +/// Parses (from VM memory) a squashed cairo dictionary of contract changes. |
| 88 | +/// Squashes the contract changes per contract address. |
| 89 | +pub fn parse_contract_changes( |
| 90 | + vm: &VirtualMachine, |
| 91 | + dict_start: Relocatable, |
| 92 | + dict_end: Relocatable, |
| 93 | +) -> HashMap<ContractAddress, FullContractChanges> { |
| 94 | + let flat_outer_dict: Vec<MaybeRelocatable> = vm |
| 95 | + .segments |
| 96 | + .memory |
| 97 | + .get_range(dict_start, (dict_end - dict_start).unwrap()) |
| 98 | + .into_iter() |
| 99 | + .map(|item| item.unwrap().into_owned()) |
| 100 | + .collect(); |
| 101 | + assert!(flat_outer_dict.len() % DICT_ACCESS_SIZE == 0, "Invalid outer dict length"); |
| 102 | + flat_outer_dict |
| 103 | + .chunks_exact(DICT_ACCESS_SIZE) |
| 104 | + .map(|chunk| { |
| 105 | + let (address, prev_state_entry_ptr, new_state_entry_ptr) = ( |
| 106 | + ContractAddress( |
| 107 | + PatriciaKey::try_from(chunk.first().unwrap().get_int().unwrap()).unwrap(), |
| 108 | + ), |
| 109 | + chunk.get(1).unwrap().get_relocatable().unwrap(), |
| 110 | + chunk.get(2).unwrap().get_relocatable().unwrap(), |
| 111 | + ); |
| 112 | + |
| 113 | + // Fetch fields of previous and new state entries. |
| 114 | + // Note that nonces and class hash addresses point to integers, while storage pointer |
| 115 | + // points to a relocatable. |
| 116 | + let ( |
| 117 | + prev_nonce_ptr, |
| 118 | + new_nonce_ptr, |
| 119 | + prev_class_hash_ptr, |
| 120 | + new_class_hash_ptr, |
| 121 | + prev_storage_ptr, |
| 122 | + new_storage_ptr, |
| 123 | + ) = [ |
| 124 | + (prev_state_entry_ptr, "nonce"), |
| 125 | + (new_state_entry_ptr, "nonce"), |
| 126 | + (prev_state_entry_ptr, "class_hash"), |
| 127 | + (new_state_entry_ptr, "class_hash"), |
| 128 | + (prev_state_entry_ptr, "storage_ptr"), |
| 129 | + (new_state_entry_ptr, "storage_ptr"), |
| 130 | + ] |
| 131 | + .into_iter() |
| 132 | + .map(|(ptr, field)| { |
| 133 | + get_address_of_nested_fields_from_base_address( |
| 134 | + ptr, |
| 135 | + CairoStruct::StateEntry, |
| 136 | + vm, |
| 137 | + &[field], |
| 138 | + &*OS_PROGRAM, |
| 139 | + ) |
| 140 | + .unwrap() |
| 141 | + }) |
| 142 | + .collect_tuple() |
| 143 | + .unwrap(); |
| 144 | + |
| 145 | + let (prev_nonce, new_nonce, prev_class_hash, new_class_hash) = |
| 146 | + [prev_nonce_ptr, new_nonce_ptr, prev_class_hash_ptr, new_class_hash_ptr] |
| 147 | + .into_iter() |
| 148 | + .map(|ptr| vm.get_integer(ptr).unwrap()) |
| 149 | + .collect_tuple() |
| 150 | + .unwrap(); |
| 151 | + let (prev_storage_ptr, new_storage_ptr) = [prev_storage_ptr, new_storage_ptr] |
| 152 | + .into_iter() |
| 153 | + .map(|ptr| vm.get_relocatable(ptr).unwrap()) |
| 154 | + .collect_tuple() |
| 155 | + .unwrap(); |
| 156 | + |
| 157 | + let storage_changes = squash_dict(vm, prev_storage_ptr, new_storage_ptr) |
| 158 | + .into_iter() |
| 159 | + .map(|(key, (prev_value, new_value))| FullContractStorageUpdate { |
| 160 | + key: StorageKey(PatriciaKey::try_from(key).unwrap()), |
| 161 | + prev_value, |
| 162 | + new_value, |
| 163 | + }) |
| 164 | + .collect(); |
| 165 | + |
| 166 | + ( |
| 167 | + address, |
| 168 | + FullContractChanges { |
| 169 | + addr: address, |
| 170 | + prev_nonce: Nonce(*prev_nonce), |
| 171 | + new_nonce: Nonce(*new_nonce), |
| 172 | + prev_class_hash: ClassHash(*prev_class_hash), |
| 173 | + new_class_hash: ClassHash(*new_class_hash), |
| 174 | + storage_changes, |
| 175 | + }, |
| 176 | + ) |
| 177 | + }) |
| 178 | + .collect() |
| 179 | +} |
0 commit comments