diff --git a/foundry.toml b/foundry.toml index 27f3dbb..6cf4b64 100644 --- a/foundry.toml +++ b/foundry.toml @@ -6,6 +6,8 @@ via_ir = true solc_version = "0.8.30" optimizer = true optimizer_runs = 20_000 +bytecode_hash = "none" +cbor_metadata = false [dependencies] diff --git a/src/interfaces/IERC7579Account.sol b/src/interfaces/IERC7579Account.sol new file mode 100644 index 0000000..e1b0d9f --- /dev/null +++ b/src/interfaces/IERC7579Account.sol @@ -0,0 +1,106 @@ +// SPDX-License-Identifier: MIT +pragma solidity ^0.8.21; + +struct Execution { + address target; + uint256 value; + bytes callData; +} + +interface IERC7579Account { + event ModuleInstalled(uint256 moduleTypeId, address module); + event ModuleUninstalled(uint256 moduleTypeId, address module); + + /** + * @dev Executes a transaction on behalf of the account. + * This function is intended to be called by ERC-4337 EntryPoint.sol + * @dev Ensure adequate authorization control: i.e. onlyEntryPointOrSelf + * + * @dev MSA MUST implement this function signature. + * If a mode is requested that is not supported by the Account, it MUST revert + * @param mode The encoded execution mode of the transaction. See ModeLib.sol for details + * @param executionCalldata The encoded execution call data + */ + function execute(bytes32 mode, bytes calldata executionCalldata) external payable; + + /** + * @dev Executes a transaction on behalf of the account. + * This function is intended to be called by Executor Modules + * @dev Ensure adequate authorization control: i.e. onlyExecutorModule + * + * @dev MSA MUST implement this function signature. + * If a mode is requested that is not supported by the Account, it MUST revert + * @param mode The encoded execution mode of the transaction. See ModeLib.sol for details + * @param executionCalldata The encoded execution call data + */ + function executeFromExecutor(bytes32 mode, bytes calldata executionCalldata) + external + payable + returns (bytes[] memory returnData); + + /** + * @dev ERC-1271 isValidSignature + * This function is intended to be used to validate a smart account signature + * and may forward the call to a validator module + * + * @param hash The hash of the data that is signed + * @param data The data that is signed + */ + function isValidSignature(bytes32 hash, bytes calldata data) external view returns (bytes4); + + /** + * @dev installs a Module of a certain type on the smart account + * @dev Implement Authorization control of your choosing + * @param moduleTypeId the module type ID according the ERC-7579 spec + * @param module the module address + * @param initData arbitrary data that may be required on the module during `onInstall` + * initialization. + */ + function installModule(uint256 moduleTypeId, address module, bytes calldata initData) external payable; + + /** + * @dev uninstalls a Module of a certain type on the smart account + * @dev Implement Authorization control of your choosing + * @param moduleTypeId the module type ID according the ERC-7579 spec + * @param module the module address + * @param deInitData arbitrary data that may be required on the module during `onUninstall` + * de-initialization. + */ + function uninstallModule(uint256 moduleTypeId, address module, bytes calldata deInitData) external payable; + + /** + * Function to check if the account supports a certain CallType or ExecType (see ModeLib.sol) + * @param encodedMode the encoded mode + */ + function supportsExecutionMode(bytes32 encodedMode) external view returns (bool); + + /** + * Function to check if the account supports installation of a certain module type Id + * @param moduleTypeId the module type ID according the ERC-7579 spec + */ + function supportsModule(uint256 moduleTypeId) external view returns (bool); + + /** + * Function to check if the account has a certain module installed + * @param moduleTypeId the module type ID according the ERC-7579 spec + * Note: keep in mind that some contracts can be multiple module types at the same time. It + * thus may be necessary to query multiple module types + * @param module the module address + * @param additionalContext additional context data that the smart account may interpret to + * identify conditions under which the module is installed. + * usually this is not necessary, but for some special hooks that + * are stored in mappings, this param might be needed + */ + function isModuleInstalled(uint256 moduleTypeId, address module, bytes calldata additionalContext) + external + view + returns (bool); + + /** + * @dev Returns the account id of the smart account + * @return accountImplementationId the account id of the smart account + * the accountId should be structured like so: + * "vendorname.accountname.semver" + */ + function accountId() external view returns (string memory accountImplementationId); +} diff --git a/src/policies/PerChainPolicy.sol b/src/policies/PerChainPolicy.sol new file mode 100644 index 0000000..f9ef9d3 --- /dev/null +++ b/src/policies/PerChainPolicy.sol @@ -0,0 +1,155 @@ +pragma solidity ^0.8.0; + +struct ChainPolicyArgs { + uint256[] chainIds; + bytes24[] callAddrAndSelector; +} + +struct ChainPolicyConfig { + bool check; + bytes24[] callAddrAndSelector; +} +import {PolicyBase} from "src/base/PolicyBase.sol"; +import {PackedUserOperation} from "account-abstraction/interfaces/PackedUserOperation.sol"; +import {IERC7579Account} from "src/interfaces/IERC7579Account.sol"; +import {IAccountExecute} from "account-abstraction/interfaces/IAccountExecute.sol"; +import {ExecMode, CallType, ExecType} from "src/types/Types.sol"; +import {LibERC7579} from "solady/accounts/LibERC7579.sol"; +import {CALLTYPE_SINGLE, CALLTYPE_BATCH} from "src/types/Constants.sol"; + +contract PerChainPolicy is PolicyBase { + error CallViolatesParamRule(); + error NotSupported(); + error AlreadyTaken(); + error InvalidId(); + mapping(bytes32 id => mapping(address account => ChainPolicyConfig)) public config; + mapping(bytes32 id => mapping(address account => bool)) public configured; + mapping(address account => uint256) public usedIds; + + function isInitialized(address account) external view override returns (bool) { + return usedIds[account] > 0; + } + + function _policyOninstall(bytes32 id, bytes calldata data) internal override { + if(configured[id][msg.sender]) { + revert AlreadyTaken(); + } + configured[id][msg.sender] = true; + usedIds[msg.sender]++; + bytes1 mode = data[0]; + if (mode == bytes1(0)) { + // check only on given chains + ChainPolicyArgs calldata args; + assembly { + args := add(data.offset, 1) + } + _installMode0(id, args); + } else if (mode == bytes1(0x01)) { + // check on all other chains than given chains + ChainPolicyArgs calldata args; + assembly { + args := add(data.offset, 1) + } + _installMode1(id, args); + } else { + revert NotSupported(); + } + } + + function _policyOnUninstall(bytes32 id, bytes calldata) internal override { + if(!configured[id][msg.sender]) { + revert InvalidId(); + } + usedIds[msg.sender]--; + delete config[id][msg.sender]; + } + + function _installMode0(bytes32 id, ChainPolicyArgs calldata args) internal { + for (uint256 i = 0; i < args.chainIds.length; i++) { + if (args.chainIds[i] == block.chainid) { + config[id][msg.sender] = ChainPolicyConfig({check: true, callAddrAndSelector: args.callAddrAndSelector}); + return; + } + } + // if not found, don't check + } + + function _installMode1(bytes32 id, ChainPolicyArgs calldata args) internal { + for (uint256 i = 0; i < args.chainIds.length; i++) { + if (args.chainIds[i] == block.chainid) { + // if found, don't check + return; + } + } + config[id][msg.sender] = ChainPolicyConfig({check: true, callAddrAndSelector: args.callAddrAndSelector}); + } + + function checkUserOpPolicy(bytes32 id, PackedUserOperation calldata userOp) + external + payable + override + returns (uint256) + { + ChainPolicyConfig storage thisConfig = config[id][msg.sender]; + if (thisConfig.check) { + _checkCallData(userOp.callData, thisConfig.callAddrAndSelector); + } + return 0; + } + + function checkSignaturePolicy(bytes32 id, address, bytes32, bytes calldata) + external + view + override + returns (uint256) + { + ChainPolicyConfig memory thisConfig = config[id][msg.sender]; + if (thisConfig.check) { + // this is not allowed for this id + return 1; + } + return 0; + } + + function _checkCallData(bytes calldata callData, bytes24[] storage callAddrAndSelector) internal { + if (bytes4(callData[0:4]) == IAccountExecute.executeUserOp.selector) { + callData = callData[4:]; + } + require(bytes4(callData[0:4]) == IERC7579Account.execute.selector); + bytes32 mode = bytes32(callData[4:36]); + bytes1 callType = LibERC7579.getCallType(mode); + bytes calldata executionCallData = callData[36:]; + if (callType == CALLTYPE_SINGLE) { + (address target, uint256 value, bytes calldata cd) = LibERC7579.decodeSingle(executionCallData); + bool permissionPass = _checkPermission(target, cd, value, callAddrAndSelector); + if (!permissionPass) { + revert CallViolatesParamRule(); + } + } else if (callType == CALLTYPE_BATCH) { + bytes32[] calldata pointers = LibERC7579.decodeBatch(executionCallData); + for (uint256 i = 0; i < pointers.length; i++) { + (address target, uint256 value, bytes calldata cd) = LibERC7579.getExecution(pointers, i); + bool permissionPass = _checkPermission(target, cd, value, callAddrAndSelector); + if (!permissionPass) { + revert CallViolatesParamRule(); + } + } + } else { + revert NotSupported(); + } + } + + function _checkPermission(address target, bytes calldata data, uint256, bytes24[] storage allowed) + internal + returns (bool) + { + for (uint256 i = 0; i < allowed.length; i++) { + address t = address(bytes20(allowed[i])); + bytes4 selector = bytes4(uint32(uint192(allowed[i]))); + if (target == t && bytes4(data[0:4]) == selector) { + return true; + } + } + return false; + } +} diff --git a/src/policies/SignaturePolicy.sol b/src/policies/SignaturePolicy.sol index 6b3d11d..45d19c1 100644 --- a/src/policies/SignaturePolicy.sol +++ b/src/policies/SignaturePolicy.sol @@ -25,15 +25,20 @@ contract SignaturePolicy is PolicyBase, IStatelessValidatorWithSender { mapping(bytes32 id => mapping(address => Status)) public status; mapping(bytes32 id => mapping(address caller => mapping(address wallet => bool))) public allowedCaller; - function isModuleType(uint256 typeID) external pure override(IModule,PolicyBase) returns (bool) { + function isModuleType(uint256 typeID) external pure override(IModule, PolicyBase) returns (bool) { return typeID == MODULE_TYPE_POLICY || typeID == MODULE_TYPE_STATELESS_VALIDATOR_WITH_SENDER; } - function isInitialized(address wallet) external view override(IModule,PolicyBase) returns (bool) { + function isInitialized(address wallet) external view override(IModule, PolicyBase) returns (bool) { return usedIds[wallet] > 0; } - function checkUserOpPolicy(bytes32 id, PackedUserOperation calldata userOp) external payable override returns (uint256) { + function checkUserOpPolicy(bytes32 id, PackedUserOperation calldata userOp) + external + payable + override + returns (uint256) + { return _validateUserOpPolicy(id, msg.sender); } diff --git a/src/policies/TimelockPolicy.sol b/src/policies/TimelockPolicy.sol index afa6835..fdd1af1 100644 --- a/src/policies/TimelockPolicy.sol +++ b/src/policies/TimelockPolicy.sol @@ -420,12 +420,10 @@ contract TimelockPolicy is PolicyBase, IStatelessValidator, IStatelessValidatorW * @notice Internal function to validate user operation policy * @dev Shared logic for both installed and stateless validator modes */ - function _validateUserOpPolicy( - bytes32 id, - PackedUserOperation calldata userOp, - bytes calldata sig, - address account - ) internal returns (uint256) { + function _validateUserOpPolicy(bytes32 id, PackedUserOperation calldata userOp, bytes calldata sig, address account) + internal + returns (uint256) + { TimelockConfig storage config = timelockConfig[id][account]; if (!config.initialized) return SIG_VALIDATION_FAILED_UINT; diff --git a/src/signers/ECDSASigner.sol b/src/signers/ECDSASigner.sol index 79ac662..376c125 100644 --- a/src/signers/ECDSASigner.sol +++ b/src/signers/ECDSASigner.sol @@ -43,10 +43,9 @@ contract ECDSASigner is SignerBase, IStatelessValidator, IStatelessValidatorWith returns (uint256) { address owner = signer[id][msg.sender]; - return - _verifySignature(userOpHash, userOp.signature, owner) - ? SIG_VALIDATION_SUCCESS_UINT - : SIG_VALIDATION_FAILED_UINT; + return _verifySignature(userOpHash, userOp.signature, owner) + ? SIG_VALIDATION_SUCCESS_UINT + : SIG_VALIDATION_FAILED_UINT; } function checkSignature(bytes32 id, address sender, bytes32 hash, bytes calldata sig) diff --git a/src/signers/WeightedECDSASigner.sol b/src/signers/WeightedECDSASigner.sol index 11cbe16..ddabe20 100644 --- a/src/signers/WeightedECDSASigner.sol +++ b/src/signers/WeightedECDSASigner.sol @@ -118,12 +118,12 @@ contract WeightedECDSASigner is EIP712, SignerBase, IStatelessValidator, IStatel return _validateStatelessSignature(hash, signature, guardians, weights, threshold); } - function validateSignatureWithDataWithSender( - address, - bytes32 hash, - bytes calldata signature, - bytes calldata data - ) external view override(IStatelessValidatorWithSender) returns (bool) { + function validateSignatureWithDataWithSender(address, bytes32 hash, bytes calldata signature, bytes calldata data) + external + view + override(IStatelessValidatorWithSender) + returns (bool) + { (address[] memory guardians, uint24[] memory weights, uint24 threshold) = abi.decode(data, (address[], uint24[], uint24)); return _validateStatelessSignature(hash, signature, guardians, weights, threshold); diff --git a/src/types/Constants.sol b/src/types/Constants.sol index ebeae9e..ce788b5 100644 --- a/src/types/Constants.sol +++ b/src/types/Constants.sol @@ -9,6 +9,9 @@ uint256 constant MODULE_TYPE_SIGNER = 6; uint256 constant MODULE_TYPE_STATELESS_VALIDATOR = 7; uint256 constant MODULE_TYPE_STATELESS_VALIDATOR_WITH_SENDER = 10; +bytes1 constant CALLTYPE_SINGLE = 0x00; +bytes1 constant CALLTYPE_BATCH = 0x01; + bytes4 constant ERC1271_MAGICVALUE = 0x1626ba7e; bytes4 constant ERC1271_INVALID = 0xffffffff; uint256 constant SIG_VALIDATION_FAILED_UINT = 1; diff --git a/src/types/Types.sol b/src/types/Types.sol index 6891a25..3a4b3a6 100644 --- a/src/types/Types.sol +++ b/src/types/Types.sol @@ -5,3 +5,15 @@ type ValidUntil is uint48; function packValidationData(ValidAfter validAfter, ValidUntil validUntil) pure returns (uint256) { return uint256(ValidAfter.unwrap(validAfter)) << 208 | uint256(ValidUntil.unwrap(validUntil)) << 160; } + +// Custom type for improved developer experience +type ExecMode is bytes32; + +type CallType is bytes1; + +type ExecType is bytes1; + +type ExecModeSelector is bytes4; + +type ExecModePayload is bytes22; + diff --git a/src/validators/ECDSAValidator.sol b/src/validators/ECDSAValidator.sol index 4b02c31..37f04b5 100644 --- a/src/validators/ECDSAValidator.sol +++ b/src/validators/ECDSAValidator.sol @@ -71,10 +71,9 @@ contract ECDSAValidator is IValidator, IHook, IStatelessValidator, IStatelessVal returns (uint256) { address owner = ecdsaValidatorStorage[msg.sender].owner; - return - _verifySignature(userOpHash, userOp.signature, owner) - ? SIG_VALIDATION_SUCCESS_UINT - : SIG_VALIDATION_FAILED_UINT; + return _verifySignature(userOpHash, userOp.signature, owner) + ? SIG_VALIDATION_SUCCESS_UINT + : SIG_VALIDATION_FAILED_UINT; } function isValidSignatureWithSender(address, bytes32 hash, bytes calldata sig) diff --git a/test/ECDSASigner.t.sol b/test/ECDSASigner.t.sol index 6ff2d95..6d95f66 100644 --- a/test/ECDSASigner.t.sol +++ b/test/ECDSASigner.t.sol @@ -8,13 +8,10 @@ import {PackedUserOperation} from "account-abstraction/interfaces/PackedUserOper import {IModule} from "src/interfaces/IERC7579Modules.sol"; import "forge-std/console.sol"; -contract ECDSASignerTest is - SignerTestBase, - StatelessValidatorTestBase, - StatelessValidatorWithSenderTestBase -{ +contract ECDSASignerTest is SignerTestBase, StatelessValidatorTestBase, StatelessValidatorWithSenderTestBase { address owner; uint256 ownerKey; + function deployModule() internal virtual override returns (IModule) { return new ECDSASigner(); } @@ -27,10 +24,13 @@ contract ECDSASignerTest is return abi.encodePacked(owner); } - function userOpSignature( - PackedUserOperation memory userOp, - bool valid - ) internal view virtual override returns (bytes memory) { + function userOpSignature(PackedUserOperation memory userOp, bool valid) + internal + view + virtual + override + returns (bytes memory) + { console.log("owner:", owner); bytes32 hash = ENTRYPOINT.getUserOpHash(userOp); if (!valid) { @@ -40,10 +40,7 @@ contract ECDSASignerTest is return abi.encodePacked(r, s, v); } - function erc1271Signature( - bytes32 hash, - bool valid - ) internal view virtual override returns (address,bytes memory) { + function erc1271Signature(bytes32 hash, bool valid) internal view virtual override returns (address, bytes memory) { if (!valid) { hash = keccak256(abi.encodePacked("invalid", hash)); } @@ -51,17 +48,23 @@ contract ECDSASignerTest is return (address(0), abi.encodePacked(r, s, v)); } - function statelessValidationSignature( - bytes32 hash, - bool valid - ) internal view virtual override returns (address, bytes memory) { + function statelessValidationSignature(bytes32 hash, bool valid) + internal + view + virtual + override + returns (address, bytes memory) + { return erc1271Signature(hash, valid); } - function statelessValidationSignatureWithSender( - bytes32 hash, - bool valid - ) internal view virtual override returns (address, bytes memory) { + function statelessValidationSignatureWithSender(bytes32 hash, bool valid) + internal + view + virtual + override + returns (address, bytes memory) + { return erc1271Signature(hash, valid); } } diff --git a/test/ECDSAValidator.t.sol b/test/ECDSAValidator.t.sol index 23765d5..8bc9390 100644 --- a/test/ECDSAValidator.t.sol +++ b/test/ECDSAValidator.t.sol @@ -9,11 +9,7 @@ import {IModule, IHook} from "src/interfaces/IERC7579Modules.sol"; import {MODULE_TYPE_HOOK} from "src/types/Constants.sol"; import "forge-std/console.sol"; -contract ECDSAValidatorTest is - ValidatorTestBase, - StatelessValidatorTestBase, - StatelessValidatorWithSenderTestBase -{ +contract ECDSAValidatorTest is ValidatorTestBase, StatelessValidatorTestBase, StatelessValidatorWithSenderTestBase { address owner; uint256 ownerKey; @@ -29,10 +25,13 @@ contract ECDSAValidatorTest is return abi.encodePacked(owner); } - function userOpSignature( - PackedUserOperation memory userOp, - bool valid - ) internal view virtual override returns (bytes memory) { + function userOpSignature(PackedUserOperation memory userOp, bool valid) + internal + view + virtual + override + returns (bytes memory) + { bytes32 hash = ENTRYPOINT.getUserOpHash(userOp); if (!valid) { hash = keccak256(abi.encodePacked("invalid", hash)); @@ -41,10 +40,13 @@ contract ECDSAValidatorTest is return abi.encodePacked(r, s, v); } - function erc1271Signature( - bytes32 hash, - bool valid - ) internal view virtual override returns (address sender, bytes memory signature) { + function erc1271Signature(bytes32 hash, bool valid) + internal + view + virtual + override + returns (address sender, bytes memory signature) + { if (!valid) { hash = keccak256(abi.encodePacked("invalid", hash)); } @@ -52,10 +54,13 @@ contract ECDSAValidatorTest is return (address(0), abi.encodePacked(r, s, v)); } - function statelessValidationSignature( - bytes32 hash, - bool valid - ) internal view virtual override returns (address, bytes memory) { + function statelessValidationSignature(bytes32 hash, bool valid) + internal + view + virtual + override + returns (address, bytes memory) + { if (!valid) { hash = keccak256(abi.encodePacked("invalid", hash)); } @@ -63,10 +68,13 @@ contract ECDSAValidatorTest is return (address(0), abi.encodePacked(r, s, v)); } - function statelessValidationSignatureWithSender( - bytes32 hash, - bool valid - ) internal view virtual override returns (address, bytes memory) { + function statelessValidationSignatureWithSender(bytes32 hash, bool valid) + internal + view + virtual + override + returns (address, bytes memory) + { return statelessValidationSignature(hash, valid); } diff --git a/test/SignaturePolicy.t.sol b/test/SignaturePolicy.t.sol index f256ff6..9a5ee40 100644 --- a/test/SignaturePolicy.t.sol +++ b/test/SignaturePolicy.t.sol @@ -8,10 +8,7 @@ import {PackedUserOperation} from "account-abstraction/interfaces/PackedUserOper import {IModule} from "src/interfaces/IERC7579Modules.sol"; import "forge-std/console.sol"; -contract SignaturePolicyTest is - PolicyTestBase, - StatelessValidatorWithSenderTestBase -{ +contract SignaturePolicyTest is PolicyTestBase, StatelessValidatorWithSenderTestBase { address allowedCaller; address disallowedCaller; @@ -30,59 +27,39 @@ contract SignaturePolicyTest is return abi.encode(callers); } - function validUserOp() - internal - view - virtual - override - returns (PackedUserOperation memory) - { + function validUserOp() internal view virtual override returns (PackedUserOperation memory) { // SignaturePolicy always passes for live policies in checkUserOpPolicy - return - PackedUserOperation({ - sender: WALLET, - nonce: 0, - initCode: "", - callData: "", - accountGasLimits: bytes32( - abi.encodePacked(uint128(100000), uint128(200000)) - ), - preVerificationGas: 0, - gasFees: bytes32(abi.encodePacked(uint128(1), uint128(1))), - paymasterAndData: "", - signature: "" - }); + return PackedUserOperation({ + sender: WALLET, + nonce: 0, + initCode: "", + callData: "", + accountGasLimits: bytes32(abi.encodePacked(uint128(100000), uint128(200000))), + preVerificationGas: 0, + gasFees: bytes32(abi.encodePacked(uint128(1), uint128(1))), + paymasterAndData: "", + signature: "" + }); } - function invalidUserOp() - internal - view - virtual - override - returns (PackedUserOperation memory) - { + function invalidUserOp() internal view virtual override returns (PackedUserOperation memory) { // For SignaturePolicy, userOp validation always passes if policy is live // To make it fail, we would need to use a non-live policy, but that's tested separately // For this test, we'll just return a userOp (the fail case is tested by not installing) - return - PackedUserOperation({ - sender: address(0xDEAD), // Different sender to simulate non-installed policy - nonce: 0, - initCode: "", - callData: "", - accountGasLimits: bytes32( - abi.encodePacked(uint128(100000), uint128(200000)) - ), - preVerificationGas: 0, - gasFees: bytes32(abi.encodePacked(uint128(1), uint128(1))), - paymasterAndData: "", - signature: "" - }); + return PackedUserOperation({ + sender: address(0xDEAD), // Different sender to simulate non-installed policy + nonce: 0, + initCode: "", + callData: "", + accountGasLimits: bytes32(abi.encodePacked(uint128(100000), uint128(200000))), + preVerificationGas: 0, + gasFees: bytes32(abi.encodePacked(uint128(1), uint128(1))), + paymasterAndData: "", + signature: "" + }); } - function validSignatureData( - bytes32 hash - ) + function validSignatureData(bytes32 hash) internal view virtual @@ -93,9 +70,7 @@ contract SignaturePolicyTest is return (allowedCaller, ""); } - function invalidSignatureData( - bytes32 hash - ) + function invalidSignatureData(bytes32 hash) internal view virtual @@ -108,11 +83,7 @@ contract SignaturePolicyTest is // Override the fail test because SignaturePolicy's checkUserOpPolicy doesn't validate based on userOp content // It only checks if the policy is live for the calling account - function testPolicyAfterInstallCheckUserOpPolicyFail() - public - payable - override - { + function testPolicyAfterInstallCheckUserOpPolicyFail() public payable override { SignaturePolicy policyModule = SignaturePolicy(address(module)); // Don't install for this account @@ -121,18 +92,18 @@ contract SignaturePolicyTest is PackedUserOperation memory userOp = validUserOp(); vm.startPrank(nonInstalledAccount); - uint256 validationResult = policyModule.checkUserOpPolicy( - policyId(), - userOp - ); + uint256 validationResult = policyModule.checkUserOpPolicy(policyId(), userOp); vm.stopPrank(); assertFalse(validationResult == 0); } - function statelessValidationSignatureWithSender( - bytes32 hash, - bool valid - ) internal view virtual override returns (address, bytes memory) { + function statelessValidationSignatureWithSender(bytes32 hash, bool valid) + internal + view + virtual + override + returns (address, bytes memory) + { return valid ? validSignatureData(hash) : invalidSignatureData(hash); } } diff --git a/test/TimelockPolicy.t.sol b/test/TimelockPolicy.t.sol index c88f914..55d7541 100644 --- a/test/TimelockPolicy.t.sol +++ b/test/TimelockPolicy.t.sol @@ -8,11 +8,7 @@ import {PackedUserOperation} from "account-abstraction/interfaces/PackedUserOper import {IModule, IStatelessValidator, IStatelessValidatorWithSender} from "src/interfaces/IERC7579Modules.sol"; import "forge-std/console.sol"; -contract TimelockPolicyTest is - PolicyTestBase, - StatelessValidatorTestBase, - StatelessValidatorWithSenderTestBase -{ +contract TimelockPolicyTest is PolicyTestBase, StatelessValidatorTestBase, StatelessValidatorWithSenderTestBase { uint48 delay = 1 days; uint48 expirationPeriod = 1 days; @@ -26,52 +22,34 @@ contract TimelockPolicyTest is return abi.encode(delay, expirationPeriod); } - function validUserOp() - internal - view - virtual - override - returns (PackedUserOperation memory) - { + function validUserOp() internal view virtual override returns (PackedUserOperation memory) { // For a valid userOp execution, we need a proposal that has been created and timelock has passed - return - PackedUserOperation({ - sender: WALLET, - nonce: 1, - initCode: "", - callData: hex"1234", // Some calldata for the proposal - accountGasLimits: bytes32( - abi.encodePacked(uint128(100000), uint128(200000)) - ), - preVerificationGas: 0, - gasFees: bytes32(abi.encodePacked(uint128(1), uint128(1))), - paymasterAndData: "", - signature: "" - }); + return PackedUserOperation({ + sender: WALLET, + nonce: 1, + initCode: "", + callData: hex"1234", // Some calldata for the proposal + accountGasLimits: bytes32(abi.encodePacked(uint128(100000), uint128(200000))), + preVerificationGas: 0, + gasFees: bytes32(abi.encodePacked(uint128(1), uint128(1))), + paymasterAndData: "", + signature: "" + }); } - function invalidUserOp() - internal - view - virtual - override - returns (PackedUserOperation memory) - { + function invalidUserOp() internal view virtual override returns (PackedUserOperation memory) { // An invalid userOp would be one without a proposal - return - PackedUserOperation({ - sender: WALLET, - nonce: 999, // No proposal created for this nonce - initCode: "", - callData: hex"abcd", - accountGasLimits: bytes32( - abi.encodePacked(uint128(100000), uint128(200000)) - ), - preVerificationGas: 0, - gasFees: bytes32(abi.encodePacked(uint128(1), uint128(1))), - paymasterAndData: "", - signature: "" - }); + return PackedUserOperation({ + sender: WALLET, + nonce: 999, // No proposal created for this nonce + initCode: "", + callData: hex"abcd", + accountGasLimits: bytes32(abi.encodePacked(uint128(100000), uint128(200000))), + preVerificationGas: 0, + gasFees: bytes32(abi.encodePacked(uint128(1), uint128(1))), + paymasterAndData: "", + signature: "" + }); } function validSignatureData( @@ -101,17 +79,31 @@ contract TimelockPolicyTest is } function statelessValidationSignature( - bytes32 /* hash */, + bytes32, + /* hash */ bool valid - ) internal view virtual override returns (address, bytes memory signature) { + ) + internal + view + virtual + override + returns (address, bytes memory signature) + { // Signature doesn't matter for TimelockPolicy return (address(0), ""); } function statelessValidationSignatureWithSender( - bytes32 /* hash */, + bytes32, + /* hash */ bool valid - ) internal view virtual override returns (address, bytes memory) { + ) + internal + view + virtual + override + returns (address, bytes memory) + { return statelessValidationSignature(bytes32(0), valid); } @@ -149,11 +141,7 @@ contract TimelockPolicyTest is } // Override the checkUserOpPolicy tests because TimelockPolicy has special behavior - function testPolicyAfterInstallCheckUserOpPolicySuccess() - public - payable - override - { + function testPolicyAfterInstallCheckUserOpPolicySuccess() public payable override { TimelockPolicy policyModule = TimelockPolicy(address(module)); vm.startPrank(WALLET); policyModule.onInstall(abi.encodePacked(policyId(), installData())); @@ -163,12 +151,7 @@ contract TimelockPolicyTest is // First create a proposal vm.startPrank(WALLET); - policyModule.createProposal( - policyId(), - WALLET, - userOp.callData, - userOp.nonce - ); + policyModule.createProposal(policyId(), WALLET, userOp.callData, userOp.nonce); vm.stopPrank(); // Fast forward past the delay @@ -176,10 +159,7 @@ contract TimelockPolicyTest is // Now execute the proposal vm.startPrank(WALLET); - uint256 validationResult = policyModule.checkUserOpPolicy( - policyId(), - userOp - ); + uint256 validationResult = policyModule.checkUserOpPolicy(policyId(), userOp); vm.stopPrank(); // For TimelockPolicy, successful execution returns packed validation data with timelock info @@ -187,11 +167,7 @@ contract TimelockPolicyTest is assertFalse(validationResult == 1); } - function testPolicyAfterInstallCheckUserOpPolicyFail() - public - payable - override - { + function testPolicyAfterInstallCheckUserOpPolicyFail() public payable override { TimelockPolicy policyModule = TimelockPolicy(address(module)); vm.startPrank(WALLET); policyModule.onInstall(abi.encodePacked(policyId(), installData())); @@ -201,10 +177,7 @@ contract TimelockPolicyTest is // Try to execute without creating a proposal first vm.startPrank(WALLET); - uint256 validationResult = policyModule.checkUserOpPolicy( - policyId(), - userOp - ); + uint256 validationResult = policyModule.checkUserOpPolicy(policyId(), userOp); vm.stopPrank(); // Should fail (return 1 = SIG_VALIDATION_FAILED_UINT) @@ -222,12 +195,7 @@ contract TimelockPolicyTest is (address sender, bytes memory sigData) = invalidSignatureData(testHash); vm.startPrank(nonInstalledWallet); - uint256 result = policyModule.checkSignaturePolicy( - policyId(), - sender, - testHash, - sigData - ); + uint256 result = policyModule.checkSignaturePolicy(policyId(), sender, testHash, sigData); vm.stopPrank(); // Should fail for non-installed account @@ -250,11 +218,8 @@ contract TimelockPolicyTest is vm.stopPrank(); // Verify proposal was created - ( - TimelockPolicy.ProposalStatus status, - uint256 validAfter, - uint256 validUntil - ) = policyModule.getProposal(WALLET, callData, nonce, policyId(), WALLET); + (TimelockPolicy.ProposalStatus status, uint256 validAfter, uint256 validUntil) = + policyModule.getProposal(WALLET, callData, nonce, policyId(), WALLET); assertEq(uint256(status), uint256(TimelockPolicy.ProposalStatus.Pending)); assertEq(validAfter, block.timestamp + delay); @@ -281,13 +246,7 @@ contract TimelockPolicyTest is vm.stopPrank(); // Verify proposal was cancelled - (TimelockPolicy.ProposalStatus status, , ) = policyModule.getProposal( - WALLET, - callData, - nonce, - policyId(), - WALLET - ); + (TimelockPolicy.ProposalStatus status,,) = policyModule.getProposal(WALLET, callData, nonce, policyId(), WALLET); assertEq(uint256(status), uint256(TimelockPolicy.ProposalStatus.Cancelled)); } @@ -305,8 +264,8 @@ contract TimelockPolicyTest is // Encode proposal data in signature bytes memory signature = abi.encodePacked( uint256(proposalCallData.length), // callDataLength - proposalCallData, // callData - proposalNonce // nonce + proposalCallData, // callData + proposalNonce // nonce ); PackedUserOperation memory userOp = PackedUserOperation({ @@ -314,9 +273,7 @@ contract TimelockPolicyTest is nonce: 0, initCode: "", callData: "", // Empty calldata = no-op - accountGasLimits: bytes32( - abi.encodePacked(uint128(100000), uint128(200000)) - ), + accountGasLimits: bytes32(abi.encodePacked(uint128(100000), uint128(200000))), preVerificationGas: 0, gasFees: bytes32(abi.encodePacked(uint128(1), uint128(1))), paymasterAndData: "", @@ -331,13 +288,8 @@ contract TimelockPolicyTest is assertEq(result, 1); // Verify proposal was created - (TimelockPolicy.ProposalStatus status, , ) = policyModule.getProposal( - WALLET, - proposalCallData, - proposalNonce, - policyId(), - WALLET - ); + (TimelockPolicy.ProposalStatus status,,) = + policyModule.getProposal(WALLET, proposalCallData, proposalNonce, policyId(), WALLET); assertEq(uint256(status), uint256(TimelockPolicy.ProposalStatus.Pending)); } diff --git a/test/WeightedECDSASigner.t.sol b/test/WeightedECDSASigner.t.sol index 4229407..b2eaa09 100644 --- a/test/WeightedECDSASigner.t.sol +++ b/test/WeightedECDSASigner.t.sol @@ -9,11 +9,7 @@ import {IModule} from "src/interfaces/IERC7579Modules.sol"; import {ECDSA} from "solady/utils/ECDSA.sol"; import "forge-std/console.sol"; -contract WeightedECDSASignerTest is - SignerTestBase, - StatelessValidatorTestBase, - StatelessValidatorWithSenderTestBase -{ +contract WeightedECDSASignerTest is SignerTestBase, StatelessValidatorTestBase, StatelessValidatorWithSenderTestBase { address guardian1; uint256 guardian1Key; address guardian2; @@ -50,10 +46,13 @@ contract WeightedECDSASignerTest is return abi.encode(guardians, weights, threshold); } - function userOpSignature( - PackedUserOperation memory userOp, - bool valid - ) internal view virtual override returns (bytes memory) { + function userOpSignature(PackedUserOperation memory userOp, bool valid) + internal + view + virtual + override + returns (bytes memory) + { bytes32 userOpHash = ENTRYPOINT.getUserOpHash(userOp); if (!valid) { @@ -106,10 +105,7 @@ contract WeightedECDSASignerTest is } } - function erc1271Signature( - bytes32 hash, - bool valid - ) internal view virtual override returns (address, bytes memory) { + function erc1271Signature(bytes32 hash, bool valid) internal view virtual override returns (address, bytes memory) { if (!valid) { hash = keccak256(abi.encodePacked("invalid", hash)); } @@ -126,10 +122,13 @@ contract WeightedECDSASignerTest is } } - function statelessValidationSignature( - bytes32 hash, - bool valid - ) internal view virtual override returns (address, bytes memory) { + function statelessValidationSignature(bytes32 hash, bool valid) + internal + view + virtual + override + returns (address, bytes memory) + { if (!valid) { hash = keccak256(abi.encodePacked("invalid", hash)); } @@ -149,10 +148,13 @@ contract WeightedECDSASignerTest is return (address(0), signatures); } - function statelessValidationSignatureWithSender( - bytes32 hash, - bool valid - ) internal view virtual override returns (address, bytes memory) { + function statelessValidationSignatureWithSender(bytes32 hash, bool valid) + internal + view + virtual + override + returns (address, bytes memory) + { return statelessValidationSignature(hash, valid); } @@ -161,14 +163,14 @@ contract WeightedECDSASignerTest is function _afterInstallCheck(bytes32 id) internal override { WeightedECDSASigner signerModule = WeightedECDSASigner(address(module)); // Check that the signer was installed by checking totalWeight for this ID - (uint24 totalWeight, , ) = signerModule.weightedStorage(id, WALLET); + (uint24 totalWeight,,) = signerModule.weightedStorage(id, WALLET); assertTrue(totalWeight > 0); } function _afterUninstallCheck(bytes32 id) internal override { WeightedECDSASigner signerModule = WeightedECDSASigner(address(module)); // Check that the signer was uninstalled by checking totalWeight is 0 for this ID - (uint24 totalWeight, , ) = signerModule.weightedStorage(id, WALLET); + (uint24 totalWeight,,) = signerModule.weightedStorage(id, WALLET); assertEq(totalWeight, 0); } diff --git a/test/base/PolicyTestBase.sol b/test/base/PolicyTestBase.sol index 34e8099..40cd916 100644 --- a/test/base/PolicyTestBase.sol +++ b/test/base/PolicyTestBase.sol @@ -9,7 +9,7 @@ import {ModuleTestBase} from "./ModuleTestBase.sol"; import {MODULE_TYPE_POLICY} from "src/types/Constants.sol"; abstract contract PolicyTestBase is ModuleTestBase { - function policyId() internal view virtual returns(bytes32) { + function policyId() internal view virtual returns (bytes32) { return keccak256(abi.encodePacked("POLICY_ID_1")); } diff --git a/test/base/SignerTestBase.sol b/test/base/SignerTestBase.sol index db98797..07d6ddf 100644 --- a/test/base/SignerTestBase.sol +++ b/test/base/SignerTestBase.sol @@ -7,23 +7,20 @@ import {IEntryPoint} from "account-abstraction/interfaces/IEntryPoint.sol"; import {EntryPointLib} from "../utils/EntryPointLib.sol"; import {ModuleTestBase} from "./ModuleTestBase.sol"; import {MODULE_TYPE_SIGNER} from "src/types/Constants.sol"; + abstract contract SignerTestBase is ModuleTestBase { - function signerId() internal view virtual returns(bytes32) { + function signerId() internal view virtual returns (bytes32) { return keccak256(abi.encodePacked("SIGNER_ID_1")); } - function userOpSignature(PackedUserOperation memory userOp, bool valid) - internal - view - virtual - returns (bytes memory); + function userOpSignature(PackedUserOperation memory userOp, bool valid) internal view virtual returns (bytes memory); function erc1271Signature(bytes32 hash, bool valid) internal view virtual returns (address sender, bytes memory signature); - + function testModuleTypeSigner() public view { ISigner signerModule = ISigner(address(module)); bool result = signerModule.isModuleType(MODULE_TYPE_SIGNER); // 6 is the module type for Signer @@ -105,7 +102,8 @@ abstract contract SignerTestBase is ModuleTestBase { userOp.signature = userOpSignature(userOp, true); vm.startPrank(WALLET); - uint256 validationResult = signerModule.checkUserOpSignature(signerId(), userOp, ENTRYPOINT.getUserOpHash(userOp)); + uint256 validationResult = + signerModule.checkUserOpSignature(signerId(), userOp, ENTRYPOINT.getUserOpHash(userOp)); vm.stopPrank(); assertEq(validationResult, 0); } @@ -132,7 +130,8 @@ abstract contract SignerTestBase is ModuleTestBase { userOp.signature = userOpSignature(userOp, false); vm.startPrank(WALLET); - uint256 validationResult = signerModule.checkUserOpSignature(signerId(), userOp, ENTRYPOINT.getUserOpHash(userOp)); + uint256 validationResult = + signerModule.checkUserOpSignature(signerId(), userOp, ENTRYPOINT.getUserOpHash(userOp)); vm.stopPrank(); assertFalse(validationResult == 0); } @@ -166,4 +165,4 @@ abstract contract SignerTestBase is ModuleTestBase { vm.stopPrank(); assertFalse(result == 0x1626ba7e); // ERC1271_INVALID } -} \ No newline at end of file +} diff --git a/test/base/StatelessValidatorTestBase.sol b/test/base/StatelessValidatorTestBase.sol index 98eff10..48d8239 100644 --- a/test/base/StatelessValidatorTestBase.sol +++ b/test/base/StatelessValidatorTestBase.sol @@ -9,14 +9,13 @@ import {ModuleTestBase} from "./ModuleTestBase.sol"; import {MODULE_TYPE_STATELESS_VALIDATOR} from "src/types/Constants.sol"; abstract contract StatelessValidatorTestBase is ModuleTestBase { - function statelessValidationSignature(bytes32 hash, bool valid) internal view virtual returns (address, bytes memory); - function testModuleTypeStatelessValidator() public view - { + + function testModuleTypeStatelessValidator() public view { IStatelessValidator validatorModule = IStatelessValidator(address(module)); bool result = validatorModule.isModuleType(MODULE_TYPE_STATELESS_VALIDATOR); // MODULE_TYPE_STATELESS_VALIDATOR = 4 assertTrue(result); @@ -39,7 +38,7 @@ abstract contract StatelessValidatorTestBase is ModuleTestBase { IStatelessValidator validatorModule = IStatelessValidator(address(module)); bytes32 message = keccak256(abi.encodePacked("TEST_MESSAGE")); - (,bytes memory sig) = statelessValidationSignature(message, false); + (, bytes memory sig) = statelessValidationSignature(message, false); vm.startPrank(WALLET); bool result = validatorModule.validateSignatureWithData(message, sig, installData()); @@ -47,5 +46,4 @@ abstract contract StatelessValidatorTestBase is ModuleTestBase { assertFalse(result); } - } diff --git a/test/base/StatelessValidatorWithSenderTestBase.sol b/test/base/StatelessValidatorWithSenderTestBase.sol index bd6d497..920b1b2 100644 --- a/test/base/StatelessValidatorWithSenderTestBase.sol +++ b/test/base/StatelessValidatorWithSenderTestBase.sol @@ -1,6 +1,5 @@ pragma solidity ^0.8.0; - import {Test} from "forge-std/Test.sol"; import {IStatelessValidatorWithSender} from "src/interfaces/IERC7579Modules.sol"; import {PackedUserOperation} from "account-abstraction/interfaces/PackedUserOperation.sol"; @@ -10,14 +9,13 @@ import {EntryPointLib} from "../utils/EntryPointLib.sol"; import {ModuleTestBase} from "./ModuleTestBase.sol"; abstract contract StatelessValidatorWithSenderTestBase is ModuleTestBase { - function statelessValidationSignatureWithSender(bytes32 hash, bool valid) internal view virtual returns (address, bytes memory); - - function testModuleTypeStatelessValidatorWithSender() public view{ + + function testModuleTypeStatelessValidatorWithSender() public view { IStatelessValidatorWithSender validatorModule = IStatelessValidatorWithSender(address(module)); bool result = validatorModule.isModuleType(MODULE_TYPE_STATELESS_VALIDATOR_WITH_SENDER); assertTrue(result); @@ -27,7 +25,7 @@ abstract contract StatelessValidatorWithSenderTestBase is ModuleTestBase { IStatelessValidatorWithSender validatorModule = IStatelessValidatorWithSender(address(module)); bytes32 message = keccak256(abi.encodePacked("TEST_MESSAGE")); - (address caller, bytes memory sig) = statelessValidationSignatureWithSender( message, true); + (address caller, bytes memory sig) = statelessValidationSignatureWithSender(message, true); vm.startPrank(WALLET); bool result = validatorModule.validateSignatureWithDataWithSender(caller, message, sig, installData()); @@ -40,7 +38,7 @@ abstract contract StatelessValidatorWithSenderTestBase is ModuleTestBase { IStatelessValidatorWithSender validatorModule = IStatelessValidatorWithSender(address(module)); bytes32 message = keccak256(abi.encodePacked("TEST_MESSAGE")); - (address caller, bytes memory sig) = statelessValidationSignatureWithSender( message, false); + (address caller, bytes memory sig) = statelessValidationSignatureWithSender(message, false); vm.startPrank(WALLET); bool result = validatorModule.validateSignatureWithDataWithSender(caller, message, sig, installData()); diff --git a/test/base/ValidatorTestBase.sol b/test/base/ValidatorTestBase.sol index f55c8b5..8aee860 100644 --- a/test/base/ValidatorTestBase.sol +++ b/test/base/ValidatorTestBase.sol @@ -9,11 +9,7 @@ import {ModuleTestBase} from "./ModuleTestBase.sol"; import {MODULE_TYPE_VALIDATOR} from "src/types/Constants.sol"; abstract contract ValidatorTestBase is ModuleTestBase { - function userOpSignature(PackedUserOperation memory userOp, bool valid) - internal - view - virtual - returns (bytes memory); + function userOpSignature(PackedUserOperation memory userOp, bool valid) internal view virtual returns (bytes memory); function erc1271Signature(bytes32 hash, bool valid) internal