diff --git a/Cargo.toml b/Cargo.toml index 4a2f98a76..819bf9960 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -119,10 +119,12 @@ handlebars = "4.4" # for cheatcode middleware foundry-cheatcodes = { git = "https://github.com/foundry-rs/foundry.git", rev = "dee41819c6e6bd1ea5419c613d226498ed7a2c59" } foundry-abi = { git = "https://github.com/foundry-rs/foundry.git", rev = "dee41819c6e6bd1ea5419c613d226498ed7a2c59" } -alloy-sol-types = "0.4.1" -alloy-dyn-abi = "0.4.1" -alloy-primitives = "0.4.1" - +alloy-sol-types = "0.4" +alloy-dyn-abi = { version = "0.4", features = ["arbitrary", "eip712"] } +alloy-primitives = "0.4" +alloy-json-abi = "0.4" +# error handling +anyhow = "1.0" # logging tracing = "0.1" tracing-subscriber = "0.3" diff --git a/src/evm/abi.rs b/src/evm/abi.rs index 0dd330320..41ac5b934 100644 --- a/src/evm/abi.rs +++ b/src/evm/abi.rs @@ -297,10 +297,15 @@ impl BoxedABI { } pub fn to_colored_string(&self) -> String { - if self.function == [0; 4] { - self.to_string() + if let Some(fn_sig) = self.get_func_signature() { + let fn_name = fn_sig.split('(').next().unwrap(); + let mut args: String = self.b.to_colored_string(); + if args.is_empty() { + args = "()".to_string(); + } + format!("{}{}", fn_name, args) } else { - format!("{}{}", self.get_func_name(), self.b.to_colored_string()) + self.to_string() } } } diff --git a/src/evm/input.rs b/src/evm/input.rs index 3ac1e7527..ec1614c4b 100644 --- a/src/evm/input.rs +++ b/src/evm/input.rs @@ -394,23 +394,19 @@ impl ConciseEVMInput { #[allow(dead_code)] #[inline] fn as_abi_call(&self, call_str: String) -> Option { - let parts: Vec<&str> = call_str.splitn(2, '(').collect(); - if parts.len() < 2 && call_str.len() == 8 { + let selector = self.fn_selector().trim_start_matches("0x").to_string(); + if self.fn_signature().is_empty() || call_str.starts_with(&selector) { return self.as_fn_selector_call(); } + let parts: Vec<&str> = call_str.splitn(2, '(').collect(); let mut fn_call = self.colored_fn_name(parts[0]).to_string(); let value = self.txn_value.unwrap_or_default(); if value != EVMU256::ZERO { fn_call.push_str(&self.colored_value()); } - if parts.len() < 2 { - fn_call.push_str("()"); - } else { - fn_call.push_str(format!("({}", parts[1]).as_str()); - } - + fn_call.push_str(format!("({}", parts[1]).as_str()); Some(format!("{}.{}", colored_address(&self.contract()), fn_call)) } @@ -568,7 +564,7 @@ impl SolutionTx for ConciseEVMInput { } fn value(&self) -> String { - self.txn_value.unwrap_or_default().to_string() + prettify_value(self.txn_value.unwrap_or_default()) } fn is_borrow(&self) -> bool { @@ -582,6 +578,14 @@ impl SolutionTx for ConciseEVMInput { fn swap_data(&self) -> HashMap { self.swap_data.clone() } + + #[cfg(not(feature = "debug"))] + fn calldata(&self) -> String { + match self.data { + Some(ref d) => hex::encode(d.get_bytes()), + None => "".to_string(), + } + } } impl HasLen for EVMInput { diff --git a/src/evm/solution/abi.rs b/src/evm/solution/abi.rs new file mode 100644 index 000000000..92c5dbfa0 --- /dev/null +++ b/src/evm/solution/abi.rs @@ -0,0 +1,325 @@ +use std::{ + collections::{BTreeMap, HashMap}, + fmt::{Display, Formatter}, +}; + +use alloy_dyn_abi::{DynSolType, DynSolValue, JsonAbiExt, ResolveSolType}; +use alloy_json_abi::Function; +use alloy_primitives::{hex, U256}; +use anyhow::Result; +use serde::Serialize; + +use crate::evm::utils::prettify_value; + +#[derive(Debug, Default)] +pub struct Abi { + /// map + /// e.g. struct SomeStruct { uint256 p0; uint256 p1; } + pub struct_defs: Option>, + /// map + /// `struct_nonces` is used to generate unique struct variable names + pub struct_nonces: HashMap, + /// map + /// e.g. SomeStruct memory s1 = SomeStruct(1, 2); + pub struct_instances: Option>, + /// map + pub arrays: Option>, + /// `array_nonce` is used to generate unique array variable names + pub array_nonce: usize, + /// map + pub tuple_struct_names: HashMap, +} + +#[derive(Debug, Default, Serialize, Clone, PartialEq, Eq)] +pub struct StructDef { + pub name: String, + // e.g. uint256 p0 + pub props: Vec, +} + +#[derive(Debug, Default)] +pub struct StructInstance { + pub struct_name: String, + pub var_name: String, + pub value: String, +} + +impl Display for StructInstance { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{} memory {} = {};", self.struct_name, self.var_name, self.value) + } +} + +#[derive(Debug, Default)] +pub struct ArrayInfo { + /// SomeType[] memory a = new SomeType[](2) + pub declaration: String, + /// a[0] = SomeType(...) + /// a[1] = SomeType(...) + pub assignments: Vec, +} + +#[derive(Debug, Default)] +pub struct DecodedArg { + pub ty: String, + pub value: String, +} + +impl DecodedArg { + pub fn new(ty: &str, value: String) -> Self { + Self { + ty: ty.to_string(), + value, + } + } +} + +impl Abi { + pub fn new() -> Self { + Self::default() + } + + pub fn decode_input(&mut self, sig: &str, input: &str) -> Result> { + let func = Function::parse(sig)?; + let calldata = hex::decode(input)?; + let tokens = func.abi_decode_input(&calldata[4..], false)?; + let mut args = vec![]; + for (i, t) in tokens.iter().enumerate() { + let token_type = func.inputs[i].resolve()?; + let arg_type = token_type.sol_type_name().to_string(); + + let arg = self.format_token(t, Some(token_type)); + args.push(DecodedArg::new(&arg_type, arg)); + } + + Ok(args) + } + + pub fn take_struct_defs(&mut self) -> HashMap { + self.struct_defs.take().unwrap_or_default() + } + + pub fn take_memory_vars(&mut self) -> Vec { + let mut vars = Vec::new(); + if let Some(arrays) = self.arrays.take() { + // put all declarations before assignments + let mut arr_assignments = Vec::new(); + for (_, array_info) in arrays { + vars.push(array_info.declaration); + arr_assignments.extend(array_info.assignments); + } + vars.extend(arr_assignments); + } + if let Some(struct_instances) = self.struct_instances.take() { + vars.extend(struct_instances.values().map(|s| s.to_string())); + } + vars + } + + fn format_token(&mut self, token: &DynSolValue, token_type: Option) -> String { + let token_type = token_type.unwrap_or_else(|| token.as_type().expect("unknown type")); + + match token { + DynSolValue::FixedArray(tokens) => self.build_array(token_type, tokens, true), + DynSolValue::Array(tokens) => self.build_array(token_type, tokens, false), + DynSolValue::Tuple(tokens) => { + let struct_name = self.build_tuple_struct_name(token_type); + let prop_names = (0..tokens.len()).map(|i| format!("p{}", i)).collect::>(); + self.build_struct(&struct_name, &prop_names, tokens) + } + DynSolValue::CustomStruct { + name, + prop_names, + tuple, + } => self.build_struct(name, prop_names, tuple), + t => format_token_raw(t), + } + } + + fn build_struct(&mut self, name: &str, prop_names: &[String], tuple: &[DynSolValue]) -> String { + // build struct signature + let types = tuple + .iter() + .map(|t| t.as_type().expect("unknown type").to_string()) + .collect::>() + .join(","); + let signature = format!("({})", types); + + // build struct definition + let props = self.build_struct_props(tuple, prop_names); + let struct_def = StructDef { + name: name.to_string(), + props, + }; + self.struct_defs + .get_or_insert_with(HashMap::new) + .insert(signature.clone(), struct_def); + + // build struct instance + self.struct_nonces + .entry(signature.clone()) + .and_modify(|nonce| *nonce += 1) + .or_insert(0); + let var_name = format!("{}{}", name.to_lowercase(), self.struct_nonces[&signature]); + let struct_instance = StructInstance { + struct_name: name.to_string(), + var_name: var_name.clone(), + value: format!("{}({})", name, self.build_struct_args(tuple)), + }; + self.struct_instances + .get_or_insert_with(BTreeMap::new) + .insert(var_name.clone(), struct_instance); + + var_name + } + + fn build_struct_props(&mut self, tuple: &[DynSolValue], prop_names: &[String]) -> Vec { + tuple + .iter() + .enumerate() + .map(|(i, t)| { + let prop_name = prop_names[i].to_string(); + let prop_type = t.as_type().expect("unknown type"); + format!("{} {}", prop_type, prop_name) + }) + .collect::>() + } + + fn build_struct_args(&mut self, tuple: &[DynSolValue]) -> String { + tuple + .iter() + .map(|t| self.format_token(t, None)) + .collect::>() + .join(", ") + } + + fn build_tuple_struct_name(&mut self, tuple: DynSolType) -> String { + let signature = tuple.sol_type_name().to_string(); + let tuple_len = self.tuple_struct_names.len(); + let struct_name = self + .tuple_struct_names + .entry(signature) + .or_insert(format!("S{}", tuple_len)); + struct_name.to_string() + } + + fn build_array(&mut self, array_type: DynSolType, tokens: &[DynSolValue], is_fixed: bool) -> String { + let array_name = self.build_array_name(); + let array_info = self.build_array_info(array_type, &array_name, tokens, is_fixed); + self.arrays + .get_or_insert_with(BTreeMap::new) + .insert(array_name.clone(), array_info); + array_name + } + + fn build_array_name(&mut self) -> String { + let name = format!("arr{}", self.array_nonce); + self.array_nonce += 1; + name + } + + fn build_array_info( + &mut self, + array_type: DynSolType, + array_name: &str, + tokens: &[DynSolValue], + is_fixed: bool, + ) -> ArrayInfo { + let assignments = tokens + .iter() + .enumerate() + .map(|(i, t)| { + format!( + "{}[{}] = {};", + array_name, + i, + self.format_array_item(t, array_type.clone()) + ) + }) + .collect::>(); + + let ty = self.format_array_type(&array_type); + let array_len = tokens.len(); + let declaration = if is_fixed { + format!("{} memory {};", ty, array_name) + } else { + format!("{} memory {} = new {}({});", ty, array_name, ty, array_len) + }; + + ArrayInfo { + declaration, + assignments, + } + } + + fn format_array_type(&self, array_type: &DynSolType) -> String { + let mut ty = array_type.sol_type_name().to_string(); + // change tuple array type to struct array type + if ty.starts_with('(') && ty.ends_with("[]") { + let tuple_sig = ty.trim_end_matches("[]"); + let struct_name = self.struct_defs.as_ref().unwrap()[tuple_sig].name.clone(); + ty = format!("{}[]", struct_name); + } + + ty + } + + fn format_array_item(&mut self, token: &DynSolValue, token_type: DynSolType) -> String { + match token { + DynSolValue::Tuple(..) | DynSolValue::CustomStruct { .. } => { + let var_name = self.format_token(token, Some(token_type)); + let struct_instance = self + .struct_instances + .get_or_insert_with(BTreeMap::new) + .remove(&var_name) + .expect("struct instance not found"); + + struct_instance.value + } + _ => self.format_token(token, Some(token_type)), + } + } +} + +pub fn format_token_raw(token: &DynSolValue) -> String { + match token { + DynSolValue::Address(addr) => addr.to_checksum(None), + DynSolValue::FixedBytes(bytes, _) => { + if bytes.is_empty() { + String::from("\"\"") + } else { + hex::encode_prefixed(bytes) + } + } + DynSolValue::Bytes(bytes) => { + if bytes.is_empty() { + String::from("\"\"") + } else { + format!("hex\"{}\"", hex::encode(bytes)) + } + } + DynSolValue::Int(num, _) => num.to_string(), + DynSolValue::Uint(num, _) => { + if num == &U256::MAX { + String::from("type(uint256).max") + } else { + prettify_value(*num) + } + } + DynSolValue::Bool(b) => b.to_string(), + DynSolValue::String(s) => format!("\"{s}\""), + DynSolValue::FixedArray(tokens) => format!("[{}]", format_array(tokens)), + DynSolValue::Array(tokens) => format!("[{}]", format_array(tokens)), + DynSolValue::Tuple(tokens) => format!("({})", format_array(tokens)), + DynSolValue::CustomStruct { + name: _, + prop_names: _, + tuple, + } => format!("({})", format_array(tuple)), + DynSolValue::Function(f) => f.to_address_and_selector().1.to_string(), + } +} + +fn format_array(tokens: &[DynSolValue]) -> String { + tokens.iter().map(format_token_raw).collect::>().join(", ") +} diff --git a/src/evm/solution/foundry_test.hbs b/src/evm/solution/foundry_test.hbs index d1d6561f7..abad62fbf 100644 --- a/src/evm/solution/foundry_test.hbs +++ b/src/evm/solution/foundry_test.hbs @@ -17,6 +17,16 @@ import "forge-std/Test.sol"; */ contract {{contract_name}} is Test { +{{#if struct_defs}} +{{#each struct_defs}} + struct {{this.name}} { + {{#each this.props}} + {{{this}}}; + {{/each}} + } +{{/each}} + +{{/if}} function setUp() public { {{#if is_onchain}} vm.createSelectFork("{{chain}}", {{block_number}}); @@ -31,12 +41,14 @@ contract {{contract_name}} is Test { {{#each trace}} vm.prank({{caller}}); {{#with this}} - {{#if raw_code}} - {{raw_code}} + {{#if interface_calls}} + {{#each interface_calls}} + {{{this}}} + {{/each}} {{else}} {{#if (is_deposit buy_type)}} vm.deal({{caller}}, {{value}}); - {{#with (lookup swap_data "deposit")~}}{{target}}.call{value: {{../value}}}(abi.encodeWithSignature("deposit()"));{{/with}} + {{#with (lookup swap_data "deposit")~}}I({{target}}).deposit{value: {{../value}}}();{{/with}} {{else}} {{#if (is_buy buy_type)}} address[] memory path{{borrow_idx}} = new address[](); @@ -53,22 +65,16 @@ contract {{contract_name}} is Test { {{#if value}} vm.deal({{caller}}, {{value}}); {{/if}} - {{#if fn_signature}} - {{contract}}.call{{#if value}}{value: {{value}}}{{/if}}(abi.encodeWithSignature( - "{{fn_signature}}"{{#if fn_args}}, {{{fn_args}}}{{/if}} - )); - {{else}} {{contract}}.call{{#if value}}{value: {{value}}}{{/if}}(abi.encodeWithSelector( - {{fn_selector}}{{#if fn_args}},{{{fn_args}}}{{/if}} + {{fn_selector}}{{#if fn_args}}, {{{fn_args}}}{{/if}} )); {{/if}} {{/if}} {{/if}} - {{/if}} {{#if (is_withdraw sell_type)}} vm.startPrank({{caller}}); uint256 amount{{balance_idx}} = IERC20({{contract}}).balanceOf(address(this)); - {{#with (lookup swap_data "withdraw")~}}{{target}}.call(abi.encodeWithSignature("withdraw(uint256)", amount{{../balance_idx}}));{{/with}} + {{#with (lookup swap_data "withdraw")~}}I({{target}}).withdraw(amount{{../balance_idx}});{{/with}} vm.stopPrank(); {{else}} {{#if (is_sell sell_type)}} @@ -97,8 +103,16 @@ contract {{contract_name}} is Test { receive() external payable {} {{/if}} } +{{#if interface}} +interface I { +{{#each interface}} + {{{this}}} +{{/each}} +} +{{/if}} {{#if include_interface}} + interface IERC20 { function balanceOf(address owner) external view returns (uint256); function approve(address spender, uint256 value) external returns (bool); diff --git a/src/evm/solution/mod.rs b/src/evm/solution/mod.rs index 3112bbec7..4923354c0 100644 --- a/src/evm/solution/mod.rs +++ b/src/evm/solution/mod.rs @@ -1,15 +1,19 @@ +mod abi; + use std::{ - collections::HashMap, + collections::{HashMap, HashSet}, fs::{self, File}, path::Path, sync::OnceLock, time::SystemTime, }; +use abi::StructDef; use handlebars::{handlebars_helper, Handlebars}; use serde::Serialize; use tracing::{debug, error}; +use self::abi::{Abi, DecodedArg}; use super::{types::EVMU256, utils, OnChainConfig}; use crate::{generic_vm::vm_state::SwapInfo, input::SolutionTx}; @@ -105,7 +109,7 @@ struct CliArgs { #[derive(Debug, Serialize, Default)] pub struct Tx { - raw_code: String, + interface_calls: Vec, // A tx can contain both a `buy` and a `sell` operation at the same time. buy_type: BuyType, sell_type: SellType, @@ -116,6 +120,7 @@ pub struct Tx { fn_signature: String, fn_selector: String, fn_args: String, + calldata: String, liq_percent: u8, balance_idx: u32, // map @@ -143,6 +148,7 @@ impl From<&T> for Tx { fn_signature: input.fn_signature(), fn_selector: input.fn_selector(), fn_args: input.fn_args(), + calldata: input.calldata(), liq_percent, swap_data, ..Default::default() @@ -150,9 +156,36 @@ impl From<&T> for Tx { } } +impl Tx { + pub fn make_interface_call(&self, args: &[DecodedArg]) -> String { + let fn_name = self + .fn_signature + .split('(') + .next() + .unwrap_or(&self.fn_selector) + .to_string(); + + let args = args + .iter() + .map(|DecodedArg { value, .. }| value.clone()) + .collect::>() + .join(", "); + + let mut fn_call = format!("I({}).{}", self.contract, fn_name); + if !self.value.is_empty() { + fn_call.push_str(format!("{{value: {}}}", self.value).as_str()); + } + fn_call.push_str(format!("({});", args).as_str()); + fn_call + } +} + #[derive(Debug, Serialize, Default)] pub struct TemplateArgs { contract_name: String, + // map + struct_defs: HashMap, + interface: HashSet, is_onchain: bool, include_interface: bool, router: String, @@ -183,10 +216,19 @@ impl TemplateArgs { let contract_name = make_contract_name(cli_args); let include_interface = trace .iter() - .any(|x| !x.raw_code.is_empty() || x.buy_type == BuyType::Buy || x.sell_type == SellType::Sell); + .any(|x| !x.interface_calls.is_empty() || x.buy_type == BuyType::Buy || x.sell_type == SellType::Sell); + + // Decode calldata and collect struct definitions and interface declarations + let mut struct_defs = HashMap::new(); + let mut interface = HashSet::new(); + for tx in trace.iter_mut() { + decode_calldata(tx, &mut struct_defs, &mut interface); + } Ok(Self { contract_name, + struct_defs, + interface, is_onchain: cli_args.is_onchain, include_interface, router, @@ -202,6 +244,56 @@ impl TemplateArgs { } } +fn decode_calldata(tx: &mut Tx, struct_defs: &mut HashMap, interface: &mut HashSet) { + // Ignore preset interfaces (IERC20 / IUniswapV2Router) + if !tx.interface_calls.is_empty() || tx.buy_type == BuyType::Buy || tx.sell_type == SellType::Sell { + return; + } + + if tx.buy_type == BuyType::Deposit || tx.sell_type == SellType::Withdraw { + if tx.buy_type == BuyType::Deposit { + interface.insert("function deposit() external payable;".to_string()); + } + if tx.sell_type == SellType::Withdraw { + interface.insert("function withdraw(uint256) external;".to_string()); + } + return; + } + + let mut abi = Abi::new(); + if let Ok(args) = abi.decode_input(&tx.fn_signature, &tx.calldata) { + let sd = abi.take_struct_defs(); + struct_defs.extend(sd.clone()); + interface.insert(make_interface(&tx.fn_signature, &sd)); + + tx.interface_calls.extend(abi.take_memory_vars()); + tx.interface_calls.push(tx.make_interface_call(args.as_slice())); + } +} + +fn make_interface(fn_sig: &str, struct_defs: &HashMap) -> String { + let mut fn_decl = fn_sig.to_string(); + // struct_signature: "(address,address,uint24)" + for (struct_signature, struct_def) in struct_defs.iter() { + let struct_arr_sig = format!("{}[]", struct_signature); + if fn_decl.contains(&struct_arr_sig) { + fn_decl = fn_decl.replace(&struct_arr_sig, &format!("{}[] memory", struct_def.name)); + } else { + fn_decl = fn_decl.replace(struct_signature, &format!("{} memory", struct_def.name)); + } + } + + let res = fn_decl + .replace("bytes,", "bytes memory,") + .replace("bytes)", "bytes memory)") + .replace("string,", "string memory,") + .replace("string)", "string memory)") + .replace("],", "] memory,") + .replace("])", "] memory)"); + + format!("function {} external payable;", res) +} + fn setup_trace(trace: &mut [Tx]) { let (mut borrow_idx, mut balance_idx) = (0, 0); for tx in trace.iter_mut() { @@ -212,8 +304,8 @@ fn setup_trace(trace: &mut [Tx]) { } // Raw code - if let Some(code) = make_raw_code(tx) { - tx.raw_code = code; + if let Some(call) = make_erc20_calls(tx) { + tx.interface_calls = vec![call]; continue; } @@ -230,7 +322,7 @@ fn setup_trace(trace: &mut [Tx]) { } } -fn make_raw_code(tx: &Tx) -> Option { +fn make_erc20_calls(tx: &Tx) -> Option { if tx.buy_type != BuyType::None { return None; } @@ -371,3 +463,98 @@ impl SellType { SellType::None } } + +#[cfg(test)] +mod tests { + use super::*; + + struct MockInput { + caller: String, + contract: String, + fn_signature: String, + fn_selector: String, + fn_args: String, + value: String, + is_borrow: bool, + liq_percent: u8, + swap_data: HashMap, + calldata: String, + } + + impl MockInput { + fn new(fn_sig: &str, calldata: &str, value: &str) -> Self { + Self { + caller: String::from(""), + contract: String::from("0xca143ce32fe78f1f7019d7d551a6402fc5350c73"), + fn_signature: String::from(fn_sig), + fn_selector: String::from(""), + fn_args: String::from(""), + value: String::from(value), + is_borrow: false, + liq_percent: 0, + swap_data: HashMap::new(), + calldata: String::from(calldata), + } + } + } + + impl SolutionTx for MockInput { + fn caller(&self) -> String { + self.caller.clone() + } + fn contract(&self) -> String { + self.contract.clone() + } + fn fn_signature(&self) -> String { + self.fn_signature.clone() + } + fn fn_selector(&self) -> String { + self.fn_selector.clone() + } + fn fn_args(&self) -> String { + self.fn_args.clone() + } + fn value(&self) -> String { + self.value.clone() + } + fn is_borrow(&self) -> bool { + self.is_borrow + } + fn liq_percent(&self) -> u8 { + self.liq_percent + } + fn swap_data(&self) -> HashMap { + self.swap_data.clone() + } + fn calldata(&self) -> String { + self.calldata.clone() + } + } + + #[test] + fn test_template_is_valid() { + let mut handlebars = Handlebars::new(); + assert!(handlebars.register_template_string("foundry_test", TEMPLATE).is_ok()); + } + + #[test] + fn test_generate_test() { + let target = "0xca143ce32fe78f1f7019d7d551a6402fc5350c73".to_string(); + let work_dir = "/tmp".to_string(); + init_cli_args(target, work_dir, &None); + + let input1 = MockInput::new( + "approve(address,uint256)", + "0x095ea7b300000000000000000000000089257a52ad585aacb1137fcc8abbd03a963b96830000000000000000000000000000000000000000000000056bc75e2d63100000", + "", + ); + let input2 = MockInput::new( + "createPair(address,address)", + "c9c65396000000000000000000000000aaec620ab3a0aa4e503c544d8715d70082da7891000000000000000000000000bb4cdb9cbd36b01bd1cbaebf2de08d9173bc095c", + "1234", + ); + let inputs = vec![input1, input2]; + let solution = String::from("solution"); + generate_test(solution, inputs); + } +} diff --git a/src/input.rs b/src/input.rs index d9788925e..3803c3acf 100644 --- a/src/input.rs +++ b/src/input.rs @@ -147,4 +147,7 @@ pub trait SolutionTx { fn swap_data(&self) -> HashMap { HashMap::new() } + fn calldata(&self) -> String { + String::from("") + } }