diff --git a/Cargo.lock b/Cargo.lock index 6ae2d6e7f0..3ca9fe653f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -251,7 +251,7 @@ checksum = "64b728d511962dda67c1bc7ea7c03736ec275ed2cf4c35d9585298ac9ccf3b73" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -276,7 +276,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -293,7 +293,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", "syn-solidity", "tiny-keccak", ] @@ -312,7 +312,7 @@ dependencies = [ "proc-macro2", "quote", "serde_json", - "syn 2.0.106", + "syn 2.0.107", "syn-solidity", ] @@ -633,7 +633,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62945a2f7e6de02a31fe400aa489f0e0f5b2502e69f95f853adb82a96c7a6b60" dependencies = [ "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -671,7 +671,7 @@ dependencies = [ "num-traits", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -785,7 +785,7 @@ checksum = "213888f660fddcca0d257e88e54ac05bca01885f258ccdf695bafd77031bb69d" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -847,7 +847,7 @@ checksum = "9035ad2d096bed7955a320ee7e2230574d28fd3c3a0f186cbea1ff3c7eed5dbb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -883,7 +883,7 @@ checksum = "ffdcb70bdbc4d478427380519163274ac86e52916e10f0a8889adf0f96d3fee7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -1431,7 +1431,7 @@ dependencies = [ "regex", "rustc-hash 1.1.0", "shlex", - "syn 2.0.106", + "syn 2.0.107", "which", ] @@ -1486,7 +1486,7 @@ checksum = "ffebfc2d28a12b262c303cb3860ee77b91bd83b1f20f0bd2a9693008e2f55a9e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -1642,7 +1642,7 @@ dependencies = [ "proc-macro2", "quote", 
"rustversion", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -1665,7 +1665,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -1959,7 +1959,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2017,7 +2017,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2404,7 +2404,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2418,7 +2418,7 @@ dependencies = [ "proc-macro2", "quote", "strsim", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2429,7 +2429,7 @@ checksum = "fc34b93ccb385b40dc71c6fceac4b2ad23662c7eeb248cf10d529b7e055b6ead" dependencies = [ "darling_core 0.20.11", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2440,7 +2440,7 @@ checksum = "ce154b9bea7fb0c8e8326e62d00354000c36e79770ff21b8c84e3aa267d9d531" dependencies = [ "darling_core 0.21.2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2508,7 +2508,7 @@ checksum = "d150dea618e920167e5973d70ae6ece4385b7164e0d799fe7c122dd0a5d912ad" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2519,7 +2519,7 @@ checksum = "2cdc8d50f426189eef89dac62fabfa0abb27d5cc008f25bf4156a0203325becc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2530,7 +2530,7 @@ checksum = "ef941ded77d15ca19b40374869ac6000af1c9f2a4c0f3d4c70926287e6364a8f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2543,7 +2543,7 @@ dependencies = [ "proc-macro2", "quote", "rustc_version 0.4.1", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2572,7 +2572,7 @@ checksum = "cb7330aeadfbe296029522e6c40f315320aba36fc43a5b3632f3795348f3bd22" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", "unicode-xid", ] @@ -2584,7 +2584,7 @@ 
checksum = "bda628edc44c4bb645fbe0f758797143e4e07926f7ebf4e9bdfbd3d2ce621df3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", "unicode-xid", ] @@ -2607,7 +2607,7 @@ dependencies = [ "dsl_auto_type", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2616,7 +2616,7 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "209c735641a413bc68c4923a9d6ad4bcb3ca306b794edaa7eb0b3228a99ffb25" dependencies = [ - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2690,7 +2690,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2701,7 +2701,7 @@ checksum = "8dc51d98e636f5e3b0759a39257458b22619cac7e96d932da6eeb052891bb67c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2721,7 +2721,7 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2787,7 +2787,7 @@ dependencies = [ "enum-ordinalize", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2900,7 +2900,7 @@ checksum = "0d28318a75d4aead5c4db25382e8ef717932d0346600cacae6357eb5941bc5ff" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2912,7 +2912,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -2923,7 +2923,7 @@ checksum = "2f9ed6b3789237c8a0c1c505af1c7eb2c560df6186f01b098c3a1064ea532f38" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -3367,7 +3367,7 @@ checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -3450,7 +3450,7 @@ dependencies = [ "proc-macro-error2", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ 
-4180,7 +4180,7 @@ checksum = "a0eb5a3343abf848c0984fe4604b2b105da9539376e24fc0a3b0007411ae4fd9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -4386,8 +4386,8 @@ dependencies = [ "openvm-ecc-sw-macros", "openvm-ecc-transpiler", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -4650,7 +4650,7 @@ checksum = "1b27834086c65ec3f9387b096d66e99f221cf081c2b738042aa252bcd41204e3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -4662,6 +4662,16 @@ dependencies = [ "regex-automata", ] +[[package]] +name = "matrixmultiply" +version = "0.3.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a06de3016e9fae57a36fd14dba131fccf49f74b40b7fbdb472f96e361ec71a08" +dependencies = [ + "autocfg", + "rawpointer", +] + [[package]] name = "maybe-rayon" version = "0.1.1" @@ -4799,6 +4809,21 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "ndarray" +version = "0.16.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "882ed72dce9365842bf196bdeedf5055305f11fc8c03dee7bb0194a6cad34841" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "portable-atomic", + "portable-atomic-util", + "rawpointer", +] + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -5004,7 +5029,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -5156,7 +5181,7 @@ version = "1.4.1-rc.2" dependencies = [ "openvm-macros-common", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -5182,7 +5207,7 @@ dependencies = [ "num-prime", "openvm-macros-common", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -5243,8 +5268,8 @@ dependencies = [ "openvm-rv32im-circuit", 
"openvm-rv32im-transpiler", "openvm-sdk", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-sdk", "openvm-transpiler", "rand 0.8.5", @@ -5411,7 +5436,7 @@ dependencies = [ "itertools 0.14.0", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -5438,8 +5463,10 @@ name = "openvm-circuit-primitives-derive" version = "1.4.1-rc.2" dependencies = [ "itertools 0.14.0", + "ndarray", + "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -5517,7 +5544,7 @@ version = "0.1.0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -5598,7 +5625,7 @@ version = "1.4.1-rc.2" dependencies = [ "openvm-macros-common", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -5665,7 +5692,7 @@ dependencies = [ "quote", "strum", "strum_macros", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -5738,7 +5765,7 @@ dependencies = [ name = "openvm-macros-common" version = "1.4.1-rc.2" dependencies = [ - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -5826,7 +5853,7 @@ name = "openvm-native-compiler-derive" version = "1.4.1-rc.2" dependencies = [ "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -6149,8 +6176,8 @@ dependencies = [ "openvm-pairing-transpiler", "openvm-rv32im-circuit", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-transpiler", @@ -6177,9 +6204,10 @@ dependencies = [ "openvm-circuit", "openvm-instructions", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-guest", - "openvm-sha256-transpiler", + "openvm-sha2-air", + "openvm-sha2-circuit", + "openvm-sha2-guest", + "openvm-sha2-transpiler", "openvm-stark-sdk", "openvm-toolchain-tests", "openvm-transpiler", @@ -6187,11 +6215,15 @@ dependencies = [ ] [[package]] -name = "openvm-sha256-air" +name = 
"openvm-sha2-air" version = "1.4.1-rc.2" dependencies = [ + "itertools 0.14.0", + "ndarray", + "num_enum", "openvm-circuit", "openvm-circuit-primitives", + "openvm-circuit-primitives-derive", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", @@ -6199,23 +6231,27 @@ dependencies = [ ] [[package]] -name = "openvm-sha256-circuit" +name = "openvm-sha2-circuit" version = "1.4.1-rc.2" dependencies = [ "cfg-if", "derive-new 0.6.0", "derive_more 1.0.0", "hex", + "itertools 0.14.0", + "ndarray", + "num_enum", "openvm-circuit", "openvm-circuit-derive", "openvm-circuit-primitives", + "openvm-circuit-primitives-derive", "openvm-cuda-backend", "openvm-cuda-builder", "openvm-cuda-common", "openvm-instructions", "openvm-rv32im-circuit", - "openvm-sha256-air", - "openvm-sha256-transpiler", + "openvm-sha2-air", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "rand 0.8.5", @@ -6225,19 +6261,19 @@ dependencies = [ ] [[package]] -name = "openvm-sha256-guest" +name = "openvm-sha2-guest" version = "1.4.1-rc.2" dependencies = [ "openvm-platform", ] [[package]] -name = "openvm-sha256-transpiler" +name = "openvm-sha2-transpiler" version = "1.4.1-rc.2" dependencies = [ "openvm-instructions", "openvm-instructions-derive", - "openvm-sha256-guest", + "openvm-sha2-guest", "openvm-stark-backend", "openvm-transpiler", "rrs-lib", @@ -6418,8 +6454,8 @@ dependencies = [ "openvm-ecc-sw-macros", "openvm-ecc-transpiler", "openvm-rv32im-transpiler", - "openvm-sha256-circuit", - "openvm-sha256-transpiler", + "openvm-sha2-circuit", + "openvm-sha2-transpiler", "openvm-stark-backend", "openvm-stark-sdk", "openvm-toolchain-tests", @@ -6824,7 +6860,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -6941,7 +6977,7 @@ dependencies = [ "proc-macro2", "proc-macro2-diagnostics", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -7010,7 +7046,7 @@ dependencies = [ "phf_shared", "proc-macro2", "quote", - "syn 
2.0.106", + "syn 2.0.107", ] [[package]] @@ -7094,6 +7130,15 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f84267b20a16ea918e43c6a88433c2d54fa145c92a811b5b047ccbe153674483" +[[package]] +name = "portable-atomic-util" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8a2f0d8d040d7848a709caf78912debcc3f33ee4b3cac47d73d1e1069e83507" +dependencies = [ + "portable-atomic", +] + [[package]] name = "poseidon-primitives" version = "0.2.0" @@ -7189,7 +7234,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" dependencies = [ "proc-macro2", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -7243,7 +7288,7 @@ dependencies = [ "proc-macro-error-attr2", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -7263,7 +7308,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", "version_check", "yansi 1.0.1", ] @@ -7472,6 +7517,12 @@ dependencies = [ "bitflags 2.9.2", ] +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + [[package]] name = "rayon" version = "1.11.0" @@ -7529,7 +7580,7 @@ checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -8207,7 +8258,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -8406,7 +8457,7 @@ checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -8482,7 +8533,7 @@ dependencies = [ "darling 
0.20.11", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -8840,7 +8891,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -8908,9 +8959,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.106" +version = "2.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ede7c438028d4436d71104916910f5bb611972c5cfd7f89b8300a8186e6fada6" +checksum = "2a26dbd934e5451d21ef060c018dae56fc073894c5a7896f882928a76e6d081b" dependencies = [ "proc-macro2", "quote", @@ -8926,7 +8977,7 @@ dependencies = [ "paste", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -8943,7 +8994,7 @@ checksum = "728a70f3dbaf5bab7f0c4b1ac8d7ae5ea60a4b5549c8a5914361c99147a709d2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -9037,7 +9088,7 @@ dependencies = [ "cfg-if", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -9048,7 +9099,7 @@ checksum = "5c89e72a01ed4c579669add59014b9a524d609c0c88c6a585ce37485879f6ffb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", "test-case-core", ] @@ -9071,7 +9122,7 @@ checksum = "451b374529930d7601b1eef8d32bc79ae870b6079b069401709c2a8bf9e75f36" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -9100,7 +9151,7 @@ checksum = "4fee6c4efc90059e10f81e6d42c60a18f76588c3d74cb83a0b242a2b6c7504c1" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -9111,7 +9162,7 @@ checksum = "44d29feb33e986b6ea906bd9c3559a856983f92371b3eaa5e83782a351623de0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -9257,7 +9308,7 @@ checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -9428,7 +9479,7 @@ checksum = 
"81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -9768,7 +9819,7 @@ dependencies = [ "log", "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", "wasm-bindgen-shared", ] @@ -9803,7 +9854,7 @@ checksum = "8ae87ea40c9f689fc23f209965b6fb8a99ad69aeeb0231408be24920604395de" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -9908,7 +9959,7 @@ checksum = "a47fddd13af08290e67f4acabf4b459f647552718f683a7b415d290ac744a836" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -9919,7 +9970,7 @@ checksum = "bd9211b69f8dcdfa817bfd14bf1c97c9188afa36f4750130fcdf3f400eca9fa8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -10258,7 +10309,7 @@ checksum = "38da3c9736e16c5d3c8c597a9aaa5d1fa565d0532ae05e27c24aa62fb32c0ab6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", "synstructure", ] @@ -10279,7 +10330,7 @@ checksum = "9ecf5b4cc5364572d7f4c329661bcc82724222973f2cab6f050a4e5c22f75181" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -10299,7 +10350,7 @@ checksum = "d71e5d6e06ab090c67b5e44993ec16b72dcbaabc526db883a360057678b48502" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", "synstructure", ] @@ -10320,7 +10371,7 @@ checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] @@ -10353,7 +10404,7 @@ checksum = "5b96237efa0c878c64bd89c436f661be4e46b2f3eff1ebb976f7ef2321d2f58f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.106", + "syn 2.0.107", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 77494e37a2..80a3bd43c7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,9 +51,9 @@ members = [ 
"extensions/keccak256/circuit", "extensions/keccak256/transpiler", "extensions/keccak256/guest", - "extensions/sha256/circuit", - "extensions/sha256/transpiler", - "extensions/sha256/guest", + "extensions/sha2/circuit", + "extensions/sha2/transpiler", + "extensions/sha2/guest", "extensions/ecc/circuit", "extensions/ecc/transpiler", "extensions/ecc/guest", @@ -123,7 +123,7 @@ openvm-cuda-common = { git = "https://github.com/openvm-org/stark-backend.git", openvm-sdk = { path = "crates/sdk", default-features = false } openvm-mod-circuit-builder = { path = "crates/circuits/mod-builder", default-features = false } openvm-poseidon2-air = { path = "crates/circuits/poseidon2-air", default-features = false } -openvm-sha256-air = { path = "crates/circuits/sha256-air", default-features = false } +openvm-sha2-air = { path = "crates/circuits/sha2-air", default-features = false } openvm-circuit-primitives = { path = "crates/circuits/primitives", default-features = false } openvm-circuit-primitives-derive = { path = "crates/circuits/primitives/derive", default-features = false } openvm = { path = "crates/toolchain/openvm", default-features = false } @@ -153,9 +153,9 @@ openvm-native-transpiler = { path = "extensions/native/transpiler", default-feat openvm-keccak256-circuit = { path = "extensions/keccak256/circuit", default-features = false } openvm-keccak256-transpiler = { path = "extensions/keccak256/transpiler", default-features = false } openvm-keccak256-guest = { path = "extensions/keccak256/guest", default-features = false } -openvm-sha256-circuit = { path = "extensions/sha256/circuit", default-features = false } -openvm-sha256-transpiler = { path = "extensions/sha256/transpiler", default-features = false } -openvm-sha256-guest = { path = "extensions/sha256/guest", default-features = false } +openvm-sha2-circuit = { path = "extensions/sha2/circuit", default-features = false } +openvm-sha2-transpiler = { path = "extensions/sha2/transpiler", default-features = false } 
+openvm-sha2-guest = { path = "extensions/sha2/guest", default-features = false } openvm-bigint-circuit = { path = "extensions/bigint/circuit", default-features = false } openvm-bigint-transpiler = { path = "extensions/bigint/transpiler", default-features = false } openvm-bigint-guest = { path = "extensions/bigint/guest", default-features = false } @@ -234,6 +234,8 @@ memmap2 = "0.9.5" libc = "0.2.175" tracing-subscriber = { version = "0.3.20", features = ["std", "env-filter"] } tokio = "1" # >=1.0.0 to allow downstream flexibility +ndarray = { version = "0.16.1", default-features = false } +num_enum = { version = "0.7.4", default-features = false } # default-features = false for no_std for use in guest programs itertools = { version = "0.14.0", default-features = false } diff --git a/benchmarks/execute/Cargo.toml b/benchmarks/execute/Cargo.toml index 5fcf58b1de..c5c42ed14a 100644 --- a/benchmarks/execute/Cargo.toml +++ b/benchmarks/execute/Cargo.toml @@ -26,8 +26,8 @@ openvm-keccak256-circuit.workspace = true openvm-keccak256-transpiler.workspace = true openvm-rv32im-circuit.workspace = true openvm-rv32im-transpiler.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true openvm-continuations = { workspace = true } openvm-native-recursion = { workspace = true } openvm-sdk = { workspace = true } diff --git a/benchmarks/execute/benches/execute.rs b/benchmarks/execute/benches/execute.rs index de3354176e..e3c556d65e 100644 --- a/benchmarks/execute/benches/execute.rs +++ b/benchmarks/execute/benches/execute.rs @@ -47,8 +47,8 @@ use openvm_sdk::{ commit::VmCommittedExe, config::{AggregationConfig, DEFAULT_NUM_CHILDREN_INTERNAL, DEFAULT_NUM_CHILDREN_LEAF}, }; -use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha2CpuProverExt}; -use openvm_sha256_transpiler::Sha256TranspilerExtension; +use openvm_sha2_circuit::{Sha2, Sha2CpuProverExt, Sha2Executor}; 
+use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::{ config::{baby_bear_poseidon2::BabyBearPoseidon2Engine, FriParameters}, engine::{StarkEngine, StarkFriEngine}, @@ -106,7 +106,7 @@ pub struct ExecuteConfig { #[extension] pub keccak: Keccak256, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, #[extension] pub modular: ModularExtension, #[extension] @@ -127,7 +127,7 @@ impl Default for ExecuteConfig { io: Rv32Io, bigint: Int256::default(), keccak: Keccak256, - sha256: Sha256, + sha2: Sha2, modular: ModularExtension::new(vec![ bn_config.modulus.clone(), bn_config.scalar.clone(), @@ -187,7 +187,7 @@ where &config.keccak, inventory, )?; - VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; + VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha2, inventory)?; VmProverExtension::::extend_prover( &AlgebraCpuProverExt, &config.modular, @@ -215,7 +215,7 @@ fn create_default_transpiler() -> Transpiler { .with_extension(Rv32MTranspilerExtension) .with_extension(Int256TranspilerExtension) .with_extension(Keccak256TranspilerExtension) - .with_extension(Sha256TranspilerExtension) + .with_extension(Sha2TranspilerExtension) .with_extension(ModularTranspilerExtension) .with_extension(Fp2TranspilerExtension) .with_extension(EccTranspilerExtension) diff --git a/benchmarks/guest/kitchen-sink/openvm.toml b/benchmarks/guest/kitchen-sink/openvm.toml index 2d1b307eef..e6cafcf57f 100644 --- a/benchmarks/guest/kitchen-sink/openvm.toml +++ b/benchmarks/guest/kitchen-sink/openvm.toml @@ -2,7 +2,7 @@ [app_vm_config.rv32m] [app_vm_config.io] [app_vm_config.keccak] -[app_vm_config.sha256] +[app_vm_config.sha2] [app_vm_config.bigint] [app_vm_config.modular] diff --git a/benchmarks/guest/sha256/openvm.toml b/benchmarks/guest/sha256/openvm.toml index 656bf52414..35f92b7195 100644 --- a/benchmarks/guest/sha256/openvm.toml +++ b/benchmarks/guest/sha256/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] 
[app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/benchmarks/guest/sha256_iter/openvm.toml b/benchmarks/guest/sha256_iter/openvm.toml index 656bf52414..35f92b7195 100644 --- a/benchmarks/guest/sha256_iter/openvm.toml +++ b/benchmarks/guest/sha256_iter/openvm.toml @@ -1,4 +1,4 @@ [app_vm_config.rv32i] [app_vm_config.rv32m] [app_vm_config.io] -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/crates/circuits/primitives/derive/Cargo.toml b/crates/circuits/primitives/derive/Cargo.toml index 06d4c00aed..23ec5d559d 100644 --- a/crates/circuits/primitives/derive/Cargo.toml +++ b/crates/circuits/primitives/derive/Cargo.toml @@ -12,6 +12,13 @@ license.workspace = true proc-macro = true [dependencies] -syn = { version = "2.0", features = ["parsing"] } +syn = { version = "2.0", features = ["full", "parsing", "extra-traits"] } quote = "1.0" -itertools = { workspace = true } +itertools = { workspace = true, default-features = true } +proc-macro2 = "1.0" + +[dev-dependencies] +ndarray.workspace = true + +[package.metadata.cargo-shear] +ignored = ["ndarray"] diff --git a/crates/circuits/primitives/derive/src/cols_ref/README.md b/crates/circuits/primitives/derive/src/cols_ref/README.md new file mode 100644 index 0000000000..82812f7b90 --- /dev/null +++ b/crates/circuits/primitives/derive/src/cols_ref/README.md @@ -0,0 +1,113 @@ +# ColsRef macro + +The `ColsRef` procedural macro is used in constraint generation to create column structs that have dynamic sizes. + +Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384). +See the [SHA-2 VM extension](../../../../../../extensions/sha2/circuit/src/sha2_chip/air.rs) for an example of how to use the `ColsRef` macro to reuse constraint generation code over multiple circuits. 
+ +## Overview + +As an illustrative example, consider the following columns struct: +```rust +struct ExampleCols { + arr: [T; N], + sum: T, +} +``` +Let's say we want to constrain `sum` to be the sum of the elements of `arr`, and `N` can be either 5 or 10. +We can define a trait that stores the config parameters. +```rust +pub trait ExampleConfig { + const N: usize; +} +``` +and then implement it for the two different configs. +```rust +pub struct ExampleConfigImplA; +impl ExampleConfig for ExampleConfigImplA { + const N: usize = 5; +} +pub struct ExampleConfigImplB; +impl ExampleConfig for ExampleConfigImplB { + const N: usize = 10; +} +``` +Then we can use the `ColsRef` macro like this +```rust +#[derive(ColsRef)] +#[config(ExampleConfig)] +struct ExampleCols { + arr: [T; N], + sum: T, +} +``` +which will generate a columns struct that uses references to the fields. +```rust +struct ExampleColsRef<'a, T, const N: usize> { + arr: ndarray::ArrayView1<'a, T>, // an n-dimensional view into the input slice (ArrayView2 for 2D arrays, etc.) + sum: &'a T, +} +``` +The `ColsRef` macro will also generate a `from` method that takes a slice of the correct length and returns an instance of the columns struct. +The `from` method is parameterized by a struct that implements the `ExampleConfig` trait, and it uses the associated constants to determine how to split the input slice into the fields of the columns struct. + +So, the constraint generation code can be written as +```rust +impl Air for ExampleAir { + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, _) = (main.row_slice(0), main.row_slice(1)); + let local_cols = ExampleColsRef::::from::(&local[..C::N + 1]); + let sum = local_cols.arr.iter().sum(); + builder.assert_eq(local_cols.sum, sum); + } +} +``` +Notes: +- the `arr` and `sum` fields of `ExampleColsRef` are references to the elements of the `local` slice. 
+- the name, `N`, of the const generic parameter must match the name of the associated constant `N` in the `ExampleConfig` trait. + +The `ColsRef` macro also generates a `ExampleColsRefMut` struct that stores mutable references to the fields, for use in trace generation. + +The `ColsRef` macro supports more than just variable-length array fields. +The field types can also be: +- any type that derives `AlignedBorrow` via `#[derive(AlignedBorrow)]` +- any type that derives `ColsRef` via `#[derive(ColsRef)]` +- (possibly nested) arrays of `T` or (possibly nested) arrays of a type that derives `AlignedBorrow` + +Note that we currently do not support arrays of types that derive `ColsRef`. + +## Specification + +Annotating a struct named `ExampleCols` with `#[derive(ColsRef)]` and `#[config(ExampleConfig)]` produces two structs, `ExampleColsRef` and `ExampleColsRefMut`. +- we assume `ExampleCols` has exactly one generic type parameter, typically named `T`, and any number of const generic parameters. Each const generic parameter must have a name that matches an associated constant in the `ExampleConfig` trait + +The fields of `ExampleColsRef` have the same names as the fields of `ExampleCols`, but their types are transformed as follows: +- type `T` becomes `&T` +- type `[T; LEN]` becomes `&ArrayView1` (see [ndarray](https://docs.rs/ndarray/latest/ndarray/index.html)) where `LEN` is an associated constant in `ExampleConfig` + - the `ExampleColsRef::from` method will correctly infer the length of the array from the config +- fields with names that end in `Cols` are assumed to be a columns struct that derives `ColsRef` and are transformed into the appropriate `ColsRef` type recursively + - one restriction is that any nested `ColsRef` type must have the same config as the outer `ColsRef` type +- fields that are annotated with `#[aligned_borrow]` are assumed to derive `AlignedBorrow` and are borrowed from the input slice. 
The new type is a reference to the `AlignedBorrow` type + - if a field whose name ends in `Cols` is annotated with `#[aligned_borrow]`, then the aligned borrow takes precedence, and the field is not transformed into an `ArrayView` +- nested arrays of `U` become `&ArrayViewX` where `X` is the number of dimensions in the nested array type + - `U` can be either the generic type `T` or a type that derives `AlignedBorrow`. In the latter case, the field must be annotated with `#[aligned_borrow]` + - the `ArrayViewX` type provides a `X`-dimensional view into the row slice + +The fields of `ExampleColsRefMut` are almost the same as the fields of `ExampleColsRef`, but they are mutable references. +- the `ArrayViewMutX` type is used instead of `ArrayViewX` for the array fields. +- fields that derive `ColsRef` are transformed into the appropriate `ColsRefMut` type recursively. + +Each of the `ExampleColsRef` and `ExampleColsRefMut` types has the following methods implemented: +```rust +// Takes a slice of the correct length and returns an instance of the columns struct. +pub const fn from(slice: &[T]) -> Self; +// Returns the number of cells in the struct +pub const fn width() -> usize; +``` +Note that the `width` method on both structs returns the same value. + +Additionally, the `ExampleColsRef` struct has a `from_mut` method that takes a `ExampleColsRefMut` and returns a `ExampleColsRef`. +This may be useful in trace generation to pass a `ExampleColsRefMut` to a function that expects a `ExampleColsRef`. + +See the [tests](../../tests/test_cols_ref.rs) for concrete examples of how the `ColsRef` macro handles each of the supported field types. 
\ No newline at end of file diff --git a/crates/circuits/primitives/derive/src/cols_ref/mod.rs b/crates/circuits/primitives/derive/src/cols_ref/mod.rs new file mode 100644 index 0000000000..f6c0028462 --- /dev/null +++ b/crates/circuits/primitives/derive/src/cols_ref/mod.rs @@ -0,0 +1,709 @@ +/* + * The `ColsRef` procedural macro is used in constraint generation to create column structs that + * have dynamic sizes. + * + * Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the + * same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384). + * See the [SHA-2 VM extension](openvm/extensions/sha2/circuit/src/sha2_chip/air.rs) for an + * example of how to use the `ColsRef` macro to reuse constraint generation code over multiple + * circuits. + * + * This macro can also be used in other situations where we want to derive Borrow for &[u8], + * for some complicated struct T. + */ +mod utils; + +use utils::*; + +extern crate proc_macro; + +use itertools::Itertools; +use quote::{format_ident, quote}; +use syn::{parse_quote, DeriveInput}; + +pub fn cols_ref_impl( + derive_input: DeriveInput, + config: proc_macro2::Ident, +) -> proc_macro2::TokenStream { + let DeriveInput { + ident, + generics, + data, + vis, + .. 
+ } = derive_input; + + let generic_types = generics + .params + .iter() + .filter_map(|p| { + if let syn::GenericParam::Type(type_param) = p { + Some(type_param) + } else { + None + } + }) + .collect::>(); + + if generic_types.len() != 1 { + panic!("Struct must have exactly one generic type parameter"); + } + + let generic_type = generic_types[0]; + + let const_generics = generics.const_params().map(|p| &p.ident).collect_vec(); + + match data { + syn::Data::Struct(data_struct) => { + // Process the fields of the struct, transforming the types for use in ColsRef struct + let const_field_infos: Vec = data_struct + .fields + .iter() + .map(|f| get_const_cols_ref_fields(f, generic_type, &const_generics)) + .collect_vec(); + + // The ColsRef struct is named by appending `Ref` to the struct name + let const_cols_ref_name = syn::Ident::new(&format!("{}Ref", ident), ident.span()); + + // the args to the `from` method will be different for the ColsRef and ColsRefMut + // structs + let from_args = quote! 
{ slice: &'a [#generic_type] }; + + // Package all the necessary information to generate the ColsRef struct + let struct_info = StructInfo { + name: const_cols_ref_name, + vis: vis.clone(), + generic_type: generic_type.clone(), + field_infos: const_field_infos, + fields: data_struct.fields.clone(), + from_args, + derive_clone: true, + }; + + // Generate the ColsRef struct + let const_cols_ref_struct = make_struct(struct_info.clone(), &config); + + // Generate the `from_mut` method for the ColsRef struct + let from_mut_impl = make_from_mut(struct_info, &config); + + // Process the fields of the struct, transforming the types for use in ColsRefMut struct + let mut_field_infos: Vec = data_struct + .fields + .iter() + .map(|f| get_mut_cols_ref_fields(f, generic_type, &const_generics)) + .collect_vec(); + + // The ColsRefMut struct is named by appending `RefMut` to the struct name + let mut_cols_ref_name = syn::Ident::new(&format!("{}RefMut", ident), ident.span()); + + // the args to the `from` method will be different for the ColsRef and ColsRefMut + // structs + let from_args = quote! { slice: &'a mut [#generic_type] }; + + // Package all the necessary information to generate the ColsRefMut struct + let struct_info = StructInfo { + name: mut_cols_ref_name, + vis, + generic_type: generic_type.clone(), + field_infos: mut_field_infos, + fields: data_struct.fields, + from_args, + derive_clone: false, + }; + + // Generate the ColsRefMut struct + let mut_cols_ref_struct = make_struct(struct_info, &config); + + quote! 
{ + #const_cols_ref_struct + #from_mut_impl + #mut_cols_ref_struct + } + } + _ => panic!("ColsRef can only be derived for structs"), + } +} + +#[derive(Debug, Clone)] +struct StructInfo { + name: syn::Ident, + vis: syn::Visibility, + generic_type: syn::TypeParam, + field_infos: Vec, + fields: syn::Fields, + from_args: proc_macro2::TokenStream, + derive_clone: bool, +} + +// Generate the ColsRef and ColsRefMut structs, depending on the value of `struct_info` +// This function is meant to reduce code duplication between the code needed to generate the two +// structs Notable differences between the two structs are: +// - the types of the fields +// - ColsRef derives Clone, but ColsRefMut cannot (since it stores mutable references) +// - the `from` method parameter is a reference to a slice for ColsRef and a mutable reference to +// a slice for ColsRefMut +fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_macro2::TokenStream { + let StructInfo { + name, + vis, + generic_type, + field_infos, + fields, + from_args, + derive_clone, + } = struct_info; + + let field_types = field_infos.iter().map(|f| &f.ty).collect_vec(); + let length_exprs = field_infos.iter().map(|f| &f.length_expr).collect_vec(); + let prepare_subslices = field_infos + .iter() + .map(|f| &f.prepare_subslice) + .collect_vec(); + let initializers = field_infos.iter().map(|f| &f.initializer).collect_vec(); + + let idents = fields.iter().map(|f| &f.ident).collect_vec(); + + let clone_impl = if derive_clone { + quote! { + #[derive(Clone)] + } + } else { + quote! {} + }; + + quote! { + #clone_impl + #[derive(Debug)] + #vis struct #name <'a, #generic_type> { + #( pub #idents: #field_types ),* + } + + impl<'a, #generic_type> #name<'a, #generic_type> { + pub fn from(#from_args) -> Self { + #( #prepare_subslices )* + Self { + #( #idents: #initializers ),* + } + } + + // Returns number of cells in the struct (where each cell has type T). 
+ // This method should only be called if the struct has no primitive types (i.e. for columns structs). + pub const fn width() -> usize { + 0 #( + #length_exprs )* + } + } + } +} + +// Generate the `from_mut` method for the ColsRef struct +fn make_from_mut(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_macro2::TokenStream { + let StructInfo { + name, + vis: _, + generic_type, + field_infos: _, + fields, + from_args: _, + derive_clone: _, + } = struct_info; + + let from_mut_impl = fields + .iter() + .map(|f| { + let ident = f.ident.clone().unwrap(); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + // calling view() on ArrayViewMut returns an ArrayView + quote! { + other.#ident.view() + } + } else if derives_aligned_borrow { + // implicitly converts a mutable reference to an immutable reference, so leave the + // field value unchanged + quote! { + other.#ident + } + } else if is_columns_struct(&f.ty) { + // lifetime 'b is used in from_mut to allow more flexible lifetime of return value + let cols_ref_type = + get_const_cols_ref_type(&f.ty, &generic_type, parse_quote! { 'b }); + // Recursively call `from_mut` on the ColsRef field + quote! { + <#cols_ref_type>::from_mut::(&other.#ident) + } + } else if is_generic_type(&f.ty, &generic_type) { + // implicitly converts a mutable reference to an immutable reference, so leave the + // field value unchanged + quote! { + &other.#ident + } + } else { + panic!("Unsupported field type (in make_from_mut): {:?}", f.ty); + } + }) + .collect_vec(); + + let field_idents = fields + .iter() + .map(|f| f.ident.clone().unwrap()) + .collect_vec(); + + let mut_struct_ident = format_ident!("{}Mut", name.to_string()); + let mut_struct_type: syn::Type = parse_quote! { + #mut_struct_ident<'a, #generic_type> + }; + + parse_quote! 
{ + // lifetime 'b is used in from_mut to allow more flexible lifetime of return value + impl<'b, #generic_type> #name<'b, #generic_type> { + pub fn from_mut<'a, C: #config>(other: &'b #mut_struct_type) -> Self + { + Self { + #( #field_idents: #from_mut_impl ),* + } + } + } + } +} + +// Information about a field that is used to generate the ColsRef and ColsRefMut structs +// See the `make_struct` function to see how this information is used +#[derive(Debug, Clone)] +struct FieldInfo { + // type for struct definition + ty: syn::Type, + // an expr calculating the length of the field + length_expr: proc_macro2::TokenStream, + // prepare a subslice of the slice to be used in the 'from' method + prepare_subslice: proc_macro2::TokenStream, + // an expr used in the Self initializer in the 'from' method + // may refer to the subslice declared in prepare_subslice + initializer: proc_macro2::TokenStream, +} + +// Prepare the fields for the const ColsRef struct +fn get_const_cols_ref_fields( + f: &syn::Field, + generic_type: &syn::TypeParam, + const_generics: &[&syn::Ident], +) -> FieldInfo { + let length_var = format_ident!("{}_length", f.ident.clone().unwrap()); + let slice_var = format_ident!("{}_slice", f.ident.clone().unwrap()); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + let ArrayInfo { dims, elem_type } = get_array_info(&f.ty, const_generics); + debug_assert!( + !dims.is_empty(), + "Array field must have at least one dimension" + ); + + let ndarray_ident: syn::Ident = format_ident!("ArrayView{}", dims.len()); + let ndarray_type: syn::Type = parse_quote! { + ndarray::#ndarray_ident<'a, #elem_type> + }; + + // dimensions of the array in terms of number of cells + let dim_exprs = dims + .iter() + .map(|d| match d { + // need to prepend C:: for const generic array dimensions + Dimension::ConstGeneric(expr) => quote! 
{ C::#expr }, + Dimension::Other(expr) => quote! { #expr }, + }) + .collect_vec(); + + if derives_aligned_borrow { + let length_expr = quote! { + <#elem_type>::width() #(* #dim_exprs)* + }; + + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at(#length_expr); + let #slice_var: &[#elem_type] = unsafe { &*(#slice_var as *const [T] as *const [#elem_type]) }; + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_columns_struct(&elem_type) { + panic!("Arrays of columns structs are currently not supported"); + } else if is_generic_type(&elem_type, generic_type) { + let length_expr = quote! { + 1 #(* #dim_exprs)* + }; + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at(#length_expr); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_primitive_type(&elem_type) { + FieldInfo { + ty: parse_quote! { + &'a #elem_type + }, + // Columns structs won't ever have primitive types, but this macro can be used on + // other structs as well, to make it easy to borrow a struct from &[u8]. + // We just set length = 0 knowing that calling the width() method is undefined if + // the struct has a primitive type. + length_expr: quote! { + 0 + }, + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at(std::mem::size_of::<#elem_type>() #(* #dim_exprs)*); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! 
{ + #slice_var + }, + } + } else { + panic!( + "Unsupported field type (in get_const_cols_ref_fields): {:?}", + f.ty + ); + } + } else if derives_aligned_borrow { + // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config) + let f_ty = &f.ty; + FieldInfo { + ty: parse_quote! { + &'a #f_ty + }, + length_expr: quote! { + <#f_ty>::width() + }, + prepare_subslice: quote! { + let #length_var = <#f_ty>::width(); + let (#slice_var, slice) = slice.split_at(#length_var); + }, + initializer: quote! { + { + use core::borrow::Borrow; + #slice_var.borrow() + } + }, + } + } else if is_columns_struct(&f.ty) { + let const_cols_ref_type = get_const_cols_ref_type(&f.ty, generic_type, parse_quote! { 'a }); + FieldInfo { + ty: parse_quote! { + #const_cols_ref_type + }, + length_expr: quote! { + <#const_cols_ref_type>::width::() + }, + prepare_subslice: quote! { + let #length_var = <#const_cols_ref_type>::width::(); + let (#slice_var, slice) = slice.split_at(#length_var); + let #slice_var = <#const_cols_ref_type>::from::(#slice_var); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_generic_type(&f.ty, generic_type) { + FieldInfo { + ty: parse_quote! { + &'a #generic_type + }, + length_expr: quote! { + 1 + }, + prepare_subslice: quote! { + let #length_var = 1; + let (#slice_var, slice) = slice.split_at(#length_var); + }, + initializer: quote! 
{ + &#slice_var[0] + }, + } + } else { + panic!( + "Unsupported field type (in get_mut_cols_ref_fields): {:?}", + f.ty + ); + } +} + +// Prepare the fields for the mut ColsRef struct +fn get_mut_cols_ref_fields( + f: &syn::Field, + generic_type: &syn::TypeParam, + const_generics: &[&syn::Ident], +) -> FieldInfo { + let length_var = format_ident!("{}_length", f.ident.clone().unwrap()); + let slice_var = format_ident!("{}_slice", f.ident.clone().unwrap()); + + let derives_aligned_borrow = f + .attrs + .iter() + .any(|attr| attr.path().is_ident("aligned_borrow")); + + let is_array = matches!(f.ty, syn::Type::Array(_)); + + if is_array { + let ArrayInfo { dims, elem_type } = get_array_info(&f.ty, const_generics); + debug_assert!( + !dims.is_empty(), + "Array field must have at least one dimension" + ); + + let ndarray_ident: syn::Ident = format_ident!("ArrayViewMut{}", dims.len()); + let ndarray_type: syn::Type = parse_quote! { + ndarray::#ndarray_ident<'a, #elem_type> + }; + + // dimensions of the array in terms of number of cells + let dim_exprs = dims + .iter() + .map(|d| match d { + // need to prepend C:: for const generic array dimensions + Dimension::ConstGeneric(expr) => quote! { C::#expr }, + Dimension::Other(expr) => quote! { #expr }, + }) + .collect_vec(); + + if derives_aligned_borrow { + let length_expr = quote! { + <#elem_type>::width() #(* #dim_exprs)* + }; + + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at_mut (#length_expr); + let #slice_var: &mut [#elem_type] = unsafe { &mut *(#slice_var as *mut [T] as *mut [#elem_type]) }; + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! 
{ + #slice_var + }, + } + } else if is_columns_struct(&elem_type) { + panic!("Arrays of columns structs are currently not supported"); + } else if is_generic_type(&elem_type, generic_type) { + let length_expr = quote! { + 1 #(* #dim_exprs)* + }; + FieldInfo { + ty: parse_quote! { + #ndarray_type + }, + length_expr: length_expr.clone(), + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at_mut(#length_expr); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_primitive_type(&elem_type) { + FieldInfo { + ty: parse_quote! { + &'a mut #elem_type + }, + // Columns structs won't ever have primitive types, but this macro can be used on + // other structs as well, to make it easy to borrow a struct from &[u8]. + // We just set length = 0 knowing that calling the width() method is undefined if + // the struct has a primitive type. + length_expr: quote! { + 0 + }, + prepare_subslice: quote! { + let (#slice_var, slice) = slice.split_at_mut(std::mem::size_of::<#elem_type>() #(* #dim_exprs)*); + let #slice_var = ndarray::#ndarray_ident::from_shape( ( #(#dim_exprs),* ) , #slice_var).unwrap(); + }, + initializer: quote! { + #slice_var + }, + } + } else { + panic!( + "Unsupported field type (in get_mut_cols_ref_fields): {:?}", + f.ty + ); + } + } else if derives_aligned_borrow { + // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config) + let f_ty = &f.ty; + FieldInfo { + ty: parse_quote! { + &'a mut #f_ty + }, + length_expr: quote! { + <#f_ty>::width() + }, + prepare_subslice: quote! { + let #length_var = <#f_ty>::width(); + let (#slice_var, slice) = slice.split_at_mut(#length_var); + }, + initializer: quote! 
{ + { + use core::borrow::BorrowMut; + #slice_var.borrow_mut() + } + }, + } + } else if is_columns_struct(&f.ty) { + let mut_cols_ref_type = get_mut_cols_ref_type(&f.ty, generic_type); + FieldInfo { + ty: parse_quote! { + #mut_cols_ref_type + }, + length_expr: quote! { + <#mut_cols_ref_type>::width::() + }, + prepare_subslice: quote! { + let #length_var = <#mut_cols_ref_type>::width::(); + let (#slice_var, slice) = slice.split_at_mut(#length_var); + let #slice_var = <#mut_cols_ref_type>::from::(#slice_var); + }, + initializer: quote! { + #slice_var + }, + } + } else if is_generic_type(&f.ty, generic_type) { + FieldInfo { + ty: parse_quote! { + &'a mut #generic_type + }, + length_expr: quote! { + 1 + }, + prepare_subslice: quote! { + let #length_var = 1; + let (#slice_var, slice) = slice.split_at_mut(#length_var); + }, + initializer: quote! { + &mut #slice_var[0] + }, + } + } else { + panic!( + "Unsupported field type (in get_mut_cols_ref_fields): {:?}", + f.ty + ); + } +} + +// Helper functions + +fn is_columns_struct(ty: &syn::Type) -> bool { + if let syn::Type::Path(type_path) = ty { + type_path + .path + .segments + .iter() + .next_back() + .map(|s| s.ident.to_string().ends_with("Cols")) + .unwrap_or(false) + } else { + false + } +} + +// If 'ty' is a struct that derives ColsRef, return the ColsRef struct type +// Otherwise, return None +fn get_const_cols_ref_type( + ty: &syn::Type, + generic_type: &syn::TypeParam, + lifetime: syn::Lifetime, +) -> syn::TypePath { + if !is_columns_struct(ty) { + panic!("Expected a columns struct, got {:?}", ty); + } + + if let syn::Type::Path(type_path) = ty { + let s = type_path.path.segments.iter().next_back().unwrap(); + if s.ident.to_string().ends_with("Cols") { + let const_cols_ref_ident = format_ident!("{}Ref", s.ident); + let const_cols_ref_type = parse_quote! 
{ + #const_cols_ref_ident<#lifetime, #generic_type> + }; + const_cols_ref_type + } else { + panic!("is_columns_struct returned true for type {:?} but the last segment is not a columns struct", ty); + } + } else { + panic!( + "is_columns_struct returned true but the type {:?} is not a path", + ty + ); + } +} + +// If 'ty' is a struct that derives ColsRef, return the ColsRefMut struct type +// Otherwise, return None +fn get_mut_cols_ref_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> syn::TypePath { + if !is_columns_struct(ty) { + panic!("Expected a columns struct, got {:?}", ty); + } + + if let syn::Type::Path(type_path) = ty { + let s = type_path.path.segments.iter().next_back().unwrap(); + if s.ident.to_string().ends_with("Cols") { + let mut_cols_ref_ident = format_ident!("{}RefMut", s.ident); + let mut_cols_ref_type = parse_quote! { + #mut_cols_ref_ident<'a, #generic_type> + }; + mut_cols_ref_type + } else { + panic!("is_columns_struct returned true for type {:?} but the last segment is not a columns struct", ty); + } + } else { + panic!( + "is_columns_struct returned true but the type {:?} is not a path", + ty + ); + } +} + +fn is_generic_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> bool { + if let syn::Type::Path(type_path) = ty { + if type_path.path.segments.len() == 1 { + type_path + .path + .segments + .iter() + .next_back() + .map(|s| s.ident == generic_type.ident) + .unwrap_or(false) + } else { + false + } + } else { + false + } +} diff --git a/crates/circuits/primitives/derive/src/cols_ref/utils.rs b/crates/circuits/primitives/derive/src/cols_ref/utils.rs new file mode 100644 index 0000000000..6203bbd00b --- /dev/null +++ b/crates/circuits/primitives/derive/src/cols_ref/utils.rs @@ -0,0 +1,102 @@ +use syn::{Expr, ExprBlock, ExprPath, Ident, Stmt, Type, TypePath}; + +pub fn is_primitive_type(ty: &Type) -> bool { + match ty { + Type::Path(TypePath { path, .. 
}) if path.segments.len() == 1 => { + matches!( + path.segments[0].ident.to_string().as_str(), + "u8" | "u16" + | "u32" + | "u64" + | "u128" + | "usize" + | "i8" + | "i16" + | "i32" + | "i64" + | "i128" + | "isize" + | "f32" + | "f64" + | "bool" + | "char" + ) + } + _ => false, + } +} + +// Type of array dimension +pub enum Dimension { + ConstGeneric(Expr), + Other(Expr), +} + +// Describes a nested array +pub struct ArrayInfo { + pub dims: Vec, + pub elem_type: Type, +} + +pub fn get_array_info(ty: &Type, const_generics: &[&Ident]) -> ArrayInfo { + let dims = get_dims(ty, const_generics); + let elem_type = get_elem_type(ty); + ArrayInfo { dims, elem_type } +} + +fn get_elem_type(ty: &Type) -> Type { + match ty { + Type::Array(array) => get_elem_type(array.elem.as_ref()), + Type::Path(_) => ty.clone(), + _ => panic!("Unsupported type: {:?}", ty), + } +} + +// Get a vector of the dimensions of the array +// Each dimension is either a constant generic or a literal integer value +fn get_dims(ty: &Type, const_generics: &[&Ident]) -> Vec { + get_dims_impl(ty, const_generics) + .into_iter() + .rev() + .collect() +} + +fn get_dims_impl(ty: &Type, const_generics: &[&Ident]) -> Vec { + match ty { + Type::Array(array) => { + let mut dims = get_dims_impl(array.elem.as_ref(), const_generics); + match &array.len { + Expr::Block(syn::ExprBlock { block, .. }) => { + if block.stmts.len() != 1 { + panic!( + "Expected exactly one statement in block, got: {:?}", + block.stmts.len() + ); + } + if let Stmt::Expr(Expr::Path(expr_path), ..) = &block.stmts[0] { + if let Some(len_ident) = expr_path.path.get_ident() { + if const_generics.contains(&len_ident) { + dims.push(Dimension::ConstGeneric(expr_path.clone().into())); + } else { + dims.push(Dimension::Other(expr_path.clone().into())); + } + } + } + } + Expr::Path(ExprPath { path, .. 
}) => { + let len_ident = path.get_ident(); + if len_ident.is_some() && const_generics.contains(&len_ident.unwrap()) { + dims.push(Dimension::ConstGeneric(array.len.clone())); + } else { + dims.push(Dimension::Other(array.len.clone())); + } + } + Expr::Lit(expr_lit) => dims.push(Dimension::Other(expr_lit.clone().into())), + _ => panic!("Unsupported array length type: {:?}", array.len), + } + dims + } + Type::Path(_) => Vec::new(), + _ => panic!("Unsupported field type (in get_dims_impl)"), + } +} diff --git a/crates/circuits/primitives/derive/src/lib.rs b/crates/circuits/primitives/derive/src/lib.rs index 35e5f8fd5b..70db5dc672 100644 --- a/crates/circuits/primitives/derive/src/lib.rs +++ b/crates/circuits/primitives/derive/src/lib.rs @@ -7,6 +7,9 @@ use proc_macro::TokenStream; use quote::quote; use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr, Meta}; +mod cols_ref; +use cols_ref::cols_ref_impl; + #[proc_macro_derive(AlignedBorrow)] pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); @@ -426,3 +429,25 @@ pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream { _ => unimplemented!(), } } + +#[proc_macro_derive(ColsRef, attributes(aligned_borrow, config))] +pub fn cols_ref_derive(input: TokenStream) -> TokenStream { + let derive_input: DeriveInput = parse_macro_input!(input as DeriveInput); + + let config = derive_input + .attrs + .iter() + .find(|attr| attr.path().is_ident("config")); + if config.is_none() { + return syn::Error::new(derive_input.ident.span(), "Config attribute is required") + .to_compile_error() + .into(); + } + let config: proc_macro2::Ident = config + .unwrap() + .parse_args() + .expect("Failed to parse config"); + + let res = cols_ref_impl(derive_input, config); + res.into() +} diff --git a/crates/circuits/primitives/derive/tests/debug.rs b/crates/circuits/primitives/derive/tests/debug.rs new file mode 100644 index 0000000000..f2d5fbbc5c --- 
/dev/null +++ b/crates/circuits/primitives/derive/tests/debug.rs @@ -0,0 +1,27 @@ +use openvm_circuit_primitives_derive::ColsRef; + +trait ExampleConfig { + const N: usize; +} +struct ExampleConfigImplA; +impl ExampleConfig for ExampleConfigImplA { + const N: usize = 5; +} + +#[allow(dead_code)] +#[derive(ColsRef)] +#[config(ExampleConfig)] +struct ExampleCols { + arr: [T; N], + // arr: [T; { N }], + sum: T, + // primitive: u32, + // array_of_primitive: [u32; { N }], +} + +#[test] +fn debug() { + let input = [1, 2, 3, 4, 5, 15]; + let test: ExampleColsRef = ExampleColsRef::from::(&input); + println!("{}, {}", test.arr, test.sum); +} diff --git a/crates/circuits/primitives/derive/tests/example.rs b/crates/circuits/primitives/derive/tests/example.rs new file mode 100644 index 0000000000..58bac9e26c --- /dev/null +++ b/crates/circuits/primitives/derive/tests/example.rs @@ -0,0 +1,87 @@ +use openvm_circuit_primitives_derive::ColsRef; + +pub trait ExampleConfig { + const N: usize; +} +pub struct ExampleConfigImplA; +impl ExampleConfig for ExampleConfigImplA { + const N: usize = 5; +} +pub struct ExampleConfigImplB; +impl ExampleConfig for ExampleConfigImplB { + const N: usize = 10; +} + +#[allow(dead_code)] +#[derive(ColsRef)] +#[config(ExampleConfig)] +struct ExampleCols { + arr: [T; N], + sum: T, +} + +#[test] +fn example() { + let input = [1, 2, 3, 4, 5, 15]; + let test: ExampleColsRef = ExampleColsRef::from::(&input); + println!("{}, {}", test.arr, test.sum); +} + +/* + * For reference, this is what the ColsRef macro expands to. + * The `cargo expand` tool is helpful for understanding how the ColsRef macro works. 
+ * See https://github.com/dtolnay/cargo-expand + +#[derive(Debug, Clone)] +struct ExampleColsRef<'a, T> { + pub arr: ndarray::ArrayView1<'a, T>, + pub sum: &'a T, +} + +impl<'a, T> ExampleColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let (arr_slice, slice) = slice.split_at(1 * C::N); + let arr_slice = ndarray::ArrayView1::from_shape((C::N), arr_slice).unwrap(); + let sum_length = 1; + let (sum_slice, slice) = slice.split_at(sum_length); + Self { + arr: arr_slice, + sum: &sum_slice[0], + } + } + pub const fn width() -> usize { + 0 + 1 * C::N + 1 + } +} + +impl<'b, T> ExampleColsRef<'b, T> { + pub fn from_mut<'a, C: ExampleConfig>(other: &'b ExampleColsRefMut<'a, T>) -> Self { + Self { + arr: other.arr.view(), + sum: &other.sum, + } + } +} + +#[derive(Debug)] +struct ExampleColsRefMut<'a, T> { + pub arr: ndarray::ArrayViewMut1<'a, T>, + pub sum: &'a mut T, +} + +impl<'a, T> ExampleColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let (arr_slice, slice) = slice.split_at_mut(1 * C::N); + let arr_slice = ndarray::ArrayViewMut1::from_shape((C::N), arr_slice).unwrap(); + let sum_length = 1; + let (sum_slice, slice) = slice.split_at_mut(sum_length); + Self { + arr: arr_slice, + sum: &mut sum_slice[0], + } + } + pub const fn width() -> usize { + 0 + 1 * C::N + 1 + } +} +*/ diff --git a/crates/circuits/primitives/derive/tests/test_cols_ref.rs b/crates/circuits/primitives/derive/tests/test_cols_ref.rs new file mode 100644 index 0000000000..6bad0c4f9f --- /dev/null +++ b/crates/circuits/primitives/derive/tests/test_cols_ref.rs @@ -0,0 +1,299 @@ +use openvm_circuit_primitives_derive::{AlignedBorrow, ColsRef}; + +pub trait TestConfig { + const N: usize; + const M: usize; +} +pub struct TestConfigImpl; +impl TestConfig for TestConfigImpl { + const N: usize = 5; + const M: usize = 2; +} + +#[allow(dead_code)] // TestCols isn't actually used in the code. 
silence clippy warning +#[derive(ColsRef)] +#[config(TestConfig)] +struct TestCols { + single_field_element: T, + array_of_t: [T; N], + nested_array_of_t: [[T; N]; N], + cols_struct: TestSubCols, + #[aligned_borrow] + array_of_aligned_borrow: [TestAlignedBorrow; N], + #[aligned_borrow] + nested_array_of_aligned_borrow: [[TestAlignedBorrow; N]; N], +} + +#[allow(dead_code)] // TestSubCols isn't actually used in the code. silence clippy warning +#[derive(ColsRef, Debug)] +#[config(TestConfig)] +struct TestSubCols { + // TestSubCols can have fields of any type that TestCols can have + a: T, + b: [T; M], + #[aligned_borrow] + c: TestAlignedBorrow, +} + +#[derive(AlignedBorrow, Debug)] +struct TestAlignedBorrow { + a: T, + b: [T; 5], +} + +#[test] +fn test_cols_ref() { + assert_eq!( + TestColsRef::::width::(), + TestColsRefMut::::width::() + ); + const WIDTH: usize = TestColsRef::::width::(); + let mut input = vec![0; WIDTH]; + let mut cols: TestColsRefMut = TestColsRefMut::from::(&mut input); + + *cols.single_field_element = 1; + cols.array_of_t[0] = 2; + cols.nested_array_of_t[[0, 0]] = 3; + *cols.cols_struct.a = 4; + cols.cols_struct.b[0] = 5; + cols.cols_struct.c.a = 6; + cols.cols_struct.c.b[0] = 7; + cols.array_of_aligned_borrow[0].a = 8; + cols.array_of_aligned_borrow[0].b[0] = 9; + cols.nested_array_of_aligned_borrow[[0, 0]].a = 10; + cols.nested_array_of_aligned_borrow[[0, 0]].b[0] = 11; + + let cols: TestColsRef = TestColsRef::from::(&input); + println!("{:?}", cols); + assert_eq!(*cols.single_field_element, 1); + assert_eq!(cols.array_of_t[0], 2); + assert_eq!(cols.nested_array_of_t[[0, 0]], 3); + assert_eq!(*cols.cols_struct.a, 4); + assert_eq!(cols.cols_struct.b[0], 5); + assert_eq!(cols.cols_struct.c.a, 6); + assert_eq!(cols.cols_struct.c.b[0], 7); + assert_eq!(cols.array_of_aligned_borrow[0].a, 8); + assert_eq!(cols.array_of_aligned_borrow[0].b[0], 9); + assert_eq!(cols.nested_array_of_aligned_borrow[[0, 0]].a, 10); + 
assert_eq!(cols.nested_array_of_aligned_borrow[[0, 0]].b[0], 11); +} + +/* + * For reference, this is what the ColsRef macro expands to. + * The `cargo expand` tool is helpful for understanding how the ColsRef macro works. + * See https://github.com/dtolnay/cargo-expand + +#[derive(Debug, Clone)] +struct TestColsRef<'a, T> { + pub single_field_element: &'a T, + pub array_of_t: ndarray::ArrayView1<'a, T>, + pub nested_array_of_t: ndarray::ArrayView2<'a, T>, + pub cols_struct: TestSubColsRef<'a, T>, + pub array_of_aligned_borrow: ndarray::ArrayView1<'a, TestAlignedBorrow>, + pub nested_array_of_aligned_borrow: ndarray::ArrayView2<'a, TestAlignedBorrow>, +} + +impl<'a, T> TestColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let single_field_element_length = 1; + let (single_field_element_slice, slice) = slice + .split_at(single_field_element_length); + let (array_of_t_slice, slice) = slice.split_at(1 * C::N); + let array_of_t_slice = ndarray::ArrayView1::from_shape((C::N), array_of_t_slice) + .unwrap(); + let (nested_array_of_t_slice, slice) = slice.split_at(1 * C::N * C::N); + let nested_array_of_t_slice = ndarray::ArrayView2::from_shape( + (C::N, C::N), + nested_array_of_t_slice, + ) + .unwrap(); + let cols_struct_length = >::width::(); + let (cols_struct_slice, slice) = slice.split_at(cols_struct_length); + let cols_struct_slice = >::from::(cols_struct_slice); + let (array_of_aligned_borrow_slice, slice) = slice + .split_at(>::width() * C::N); + let array_of_aligned_borrow_slice: &[TestAlignedBorrow] = unsafe { + &*(array_of_aligned_borrow_slice as *const [T] + as *const [TestAlignedBorrow]) + }; + let array_of_aligned_borrow_slice = ndarray::ArrayView1::from_shape( + (C::N), + array_of_aligned_borrow_slice, + ) + .unwrap(); + let (nested_array_of_aligned_borrow_slice, slice) = slice + .split_at(>::width() * C::N * C::N); + let nested_array_of_aligned_borrow_slice: &[TestAlignedBorrow] = unsafe { + &*(nested_array_of_aligned_borrow_slice as *const [T] + 
as *const [TestAlignedBorrow]) + }; + let nested_array_of_aligned_borrow_slice = ndarray::ArrayView2::from_shape( + (C::N, C::N), + nested_array_of_aligned_borrow_slice, + ) + .unwrap(); + Self { + single_field_element: &single_field_element_slice[0], + array_of_t: array_of_t_slice, + nested_array_of_t: nested_array_of_t_slice, + cols_struct: cols_struct_slice, + array_of_aligned_borrow: array_of_aligned_borrow_slice, + nested_array_of_aligned_borrow: nested_array_of_aligned_borrow_slice, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::N + 1 * C::N * C::N + >::width::() + + >::width() * C::N + + >::width() * C::N * C::N + } +} + +impl<'b, T> TestColsRef<'b, T> { + pub fn from_mut<'a, C: TestConfig>(other: &'b TestColsRefMut<'a, T>) -> Self { + Self { + single_field_element: &other.single_field_element, + array_of_t: other.array_of_t.view(), + nested_array_of_t: other.nested_array_of_t.view(), + cols_struct: >::from_mut::(&other.cols_struct), + array_of_aligned_borrow: other.array_of_aligned_borrow.view(), + nested_array_of_aligned_borrow: other.nested_array_of_aligned_borrow.view(), + } + } +} + +#[derive(Debug)] +struct TestColsRefMut<'a, T> { + pub single_field_element: &'a mut T, + pub array_of_t: ndarray::ArrayViewMut1<'a, T>, + pub nested_array_of_t: ndarray::ArrayViewMut2<'a, T>, + pub cols_struct: TestSubColsRefMut<'a, T>, + pub array_of_aligned_borrow: ndarray::ArrayViewMut1<'a, TestAlignedBorrow>, + pub nested_array_of_aligned_borrow: ndarray::ArrayViewMut2<'a, TestAlignedBorrow>, +} + +impl<'a, T> TestColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let single_field_element_length = 1; + let (single_field_element_slice, slice) = slice + .split_at_mut(single_field_element_length); + let (array_of_t_slice, slice) = slice.split_at_mut(1 * C::N); + let array_of_t_slice = ndarray::ArrayViewMut1::from_shape( + (C::N), + array_of_t_slice, + ) + .unwrap(); + let (nested_array_of_t_slice, slice) = slice.split_at_mut(1 * C::N * C::N); + 
let nested_array_of_t_slice = ndarray::ArrayViewMut2::from_shape( + (C::N, C::N), + nested_array_of_t_slice, + ) + .unwrap(); + let cols_struct_length = >::width::(); + let (cols_struct_slice, slice) = slice.split_at_mut(cols_struct_length); + let cols_struct_slice = >::from::(cols_struct_slice); + let (array_of_aligned_borrow_slice, slice) = slice + .split_at_mut(>::width() * C::N); + let array_of_aligned_borrow_slice: &mut [TestAlignedBorrow] = unsafe { + &mut *(array_of_aligned_borrow_slice as *mut [T] + as *mut [TestAlignedBorrow]) + }; + let array_of_aligned_borrow_slice = ndarray::ArrayViewMut1::from_shape( + (C::N), + array_of_aligned_borrow_slice, + ) + .unwrap(); + let (nested_array_of_aligned_borrow_slice, slice) = slice + .split_at_mut(>::width() * C::N * C::N); + let nested_array_of_aligned_borrow_slice: &mut [TestAlignedBorrow] = unsafe { + &mut *(nested_array_of_aligned_borrow_slice as *mut [T] + as *mut [TestAlignedBorrow]) + }; + let nested_array_of_aligned_borrow_slice = ndarray::ArrayViewMut2::from_shape( + (C::N, C::N), + nested_array_of_aligned_borrow_slice, + ) + .unwrap(); + Self { + single_field_element: &mut single_field_element_slice[0], + array_of_t: array_of_t_slice, + nested_array_of_t: nested_array_of_t_slice, + cols_struct: cols_struct_slice, + array_of_aligned_borrow: array_of_aligned_borrow_slice, + nested_array_of_aligned_borrow: nested_array_of_aligned_borrow_slice, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::N + 1 * C::N * C::N + >::width::() + + >::width() * C::N + + >::width() * C::N * C::N + } +} + +#[derive(Debug, Clone)] +struct TestSubColsRef<'a, T> { + pub a: &'a T, + pub b: ndarray::ArrayView1<'a, T>, + pub c: &'a TestAlignedBorrow, +} + +impl<'a, T> TestSubColsRef<'a, T> { + pub fn from(slice: &'a [T]) -> Self { + let a_length = 1; + let (a_slice, slice) = slice.split_at(a_length); + let (b_slice, slice) = slice.split_at(1 * C::M); + let b_slice = ndarray::ArrayView1::from_shape((C::M), b_slice).unwrap(); + 
let c_length = >::width(); + let (c_slice, slice) = slice.split_at(c_length); + Self { + a: &a_slice[0], + b: b_slice, + c: { + use core::borrow::Borrow; + c_slice.borrow() + }, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::M + >::width() + } +} + +impl<'b, T> TestSubColsRef<'b, T> { + pub fn from_mut<'a, C: TestConfig>(other: &'b TestSubColsRefMut<'a, T>) -> Self { + Self { + a: &other.a, + b: other.b.view(), + c: other.c, + } + } +} + +#[derive(Debug)] +struct TestSubColsRefMut<'a, T> { + pub a: &'a mut T, + pub b: ndarray::ArrayViewMut1<'a, T>, + pub c: &'a mut TestAlignedBorrow, +} + +impl<'a, T> TestSubColsRefMut<'a, T> { + pub fn from(slice: &'a mut [T]) -> Self { + let a_length = 1; + let (a_slice, slice) = slice.split_at_mut(a_length); + let (b_slice, slice) = slice.split_at_mut(1 * C::M); + let b_slice = ndarray::ArrayViewMut1::from_shape((C::M), b_slice).unwrap(); + let c_length = >::width(); + let (c_slice, slice) = slice.split_at_mut(c_length); + Self { + a: &mut a_slice[0], + b: b_slice, + c: { + use core::borrow::BorrowMut; + c_slice.borrow_mut() + }, + } + } + pub const fn width() -> usize { + 0 + 1 + 1 * C::M + >::width() + } +} +*/ diff --git a/crates/circuits/sha256-air/Cargo.toml b/crates/circuits/sha2-air/Cargo.toml similarity index 66% rename from crates/circuits/sha256-air/Cargo.toml rename to crates/circuits/sha2-air/Cargo.toml index c376a1ffdd..11e8d76654 100644 --- a/crates/circuits/sha256-air/Cargo.toml +++ b/crates/circuits/sha2-air/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "openvm-sha256-air" +name = "openvm-sha2-air" version.workspace = true authors.workspace = true edition.workspace = true @@ -7,8 +7,13 @@ edition.workspace = true [dependencies] openvm-circuit-primitives = { workspace = true } openvm-stark-backend = { workspace = true } +openvm-circuit-primitives-derive = { workspace = true } + sha2 = { version = "0.10", features = ["compress"] } rand.workspace = true +ndarray.workspace = true +num_enum = { workspace = 
true } +itertools = { workspace = true } [dev-dependencies] openvm-stark-sdk = { workspace = true } @@ -16,4 +21,4 @@ openvm-circuit = { workspace = true, features = ["test-utils"] } [features] default = ["parallel"] -parallel = ["openvm-stark-backend/parallel"] +parallel = ["openvm-stark-backend/parallel"] \ No newline at end of file diff --git a/crates/circuits/sha2-air/src/air.rs b/crates/circuits/sha2-air/src/air.rs new file mode 100644 index 0000000000..13e5ff71d5 --- /dev/null +++ b/crates/circuits/sha2-air/src/air.rs @@ -0,0 +1,646 @@ +use std::{cmp::max, iter::once, marker::PhantomData}; + +use ndarray::s; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupBus, + encoder::Encoder, + utils::{not, select}, + SubAir, +}; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use super::{ + big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field, + small_sig1_field, +}; +use crate::{ + constraint_word_addition, word_into_u16_limbs, Sha2BlockHasherSubairConfig, Sha2DigestColsRef, + Sha2RoundColsRef, +}; + +/// Expects the message to be padded to a multiple of C::BLOCK_WORDS * C::WORD_BITS bits +#[derive(Clone, Debug)] +pub struct Sha2BlockHasherSubAir { + pub bitwise_lookup_bus: BitwiseOperationLookupBus, + pub row_idx_encoder: Encoder, + /// Internal bus for self-interactions in this AIR. 
+ _phantom: PhantomData, +} + +impl Sha2BlockHasherSubAir { + pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus) -> Self { + Self { + bitwise_lookup_bus, + row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK + 1, 2, false), /* + 1 for dummy + * (padding) rows */ + _phantom: PhantomData, + } + } +} + +impl BaseAir for Sha2BlockHasherSubAir { + fn width(&self) -> usize { + C::SUBAIR_WIDTH + } +} + +impl SubAir + for Sha2BlockHasherSubAir +{ + /// The start column for the sub-air to use + type AirContext<'a> + = usize + where + Self: 'a, + AB: 'a, + ::Var: 'a, + ::Expr: 'a; + + fn eval<'a>(&'a self, builder: &'a mut AB, start_col: Self::AirContext<'a>) + where + AB::Var: 'a, + AB::Expr: 'a, + { + self.eval_row(builder, start_col); + self.eval_transitions(builder, start_col); + } +} + +impl Sha2BlockHasherSubAir { + /// Implements the single row constraints (i.e. imposes constraints only on local) + /// Implements some sanity constraints on the row index, flags, and work variables + fn eval_row(&self, builder: &mut AB, start_col: usize) { + let main = builder.main(); + let local = main.row_slice(0); + + // Doesn't matter which column struct we use here as we are only interested in the common + // columns + let local_cols: Sha2DigestColsRef = + Sha2DigestColsRef::from::(&local[start_col..start_col + C::SUBAIR_DIGEST_WIDTH]); + let flags = &local_cols.flags; + builder.assert_bool(*flags.is_round_row); + builder.assert_bool(*flags.is_first_4_rows); + builder.assert_bool(*flags.is_digest_row); + builder.assert_bool(*flags.is_round_row + *flags.is_digest_row); + builder.assert_bool(*flags.is_last_block); + + self.row_idx_encoder + .eval(builder, local_cols.flags.row_idx.to_slice().unwrap()); + builder.assert_one(self.row_idx_encoder.contains_flag_range::( + local_cols.flags.row_idx.to_slice().unwrap(), + 0..=C::ROWS_PER_BLOCK, + )); + builder.assert_eq( + self.row_idx_encoder + .contains_flag_range::(local_cols.flags.row_idx.to_slice().unwrap(), 0..=3), + 
*flags.is_first_4_rows, + ); + builder.assert_eq( + self.row_idx_encoder.contains_flag_range::( + local_cols.flags.row_idx.to_slice().unwrap(), + 0..=C::ROUND_ROWS - 1, + ), + *flags.is_round_row, + ); + builder.assert_eq( + self.row_idx_encoder.contains_flag::( + local_cols.flags.row_idx.to_slice().unwrap(), + &[C::ROUND_ROWS], + ), + *flags.is_digest_row, + ); + // If padding row we want the row_idx to be C::ROWS_PER_BLOCK + builder.assert_eq( + self.row_idx_encoder.contains_flag::( + local_cols.flags.row_idx.to_slice().unwrap(), + &[C::ROWS_PER_BLOCK], + ), + flags.is_padding_row(), + ); + + // Constrain a, e, being composed of bits: we make sure a and e are always in the same place + // in the trace matrix Note: this has to be true for every row, even padding rows + for i in 0..C::ROUNDS_PER_ROW { + for j in 0..C::WORD_BITS { + builder.assert_bool(local_cols.hash.a[[i, j]]); + builder.assert_bool(local_cols.hash.e[[i, j]]); + } + } + } + + /// Implements constraints for a digest row that ensure proper state transitions between blocks + /// This validates that: + /// The work variables are correctly initialized for the next message block + /// For the last message block, the initial state matches SHA_H constants + fn eval_digest_row( + &self, + builder: &mut AB, + local: Sha2RoundColsRef, + next: Sha2DigestColsRef, + ) { + // Check that if this is the last row of a message or an inpadding row, the hash should be + // the [SHA_H] + for i in 0..C::ROUNDS_PER_ROW { + let a = next.hash.a.row(i).mapv(|x| x.into()).to_vec(); + let e = next.hash.e.row(i).mapv(|x| x.into()).to_vec(); + + for j in 0..C::WORD_U16S { + let a_limb = compose::(&a[j * 16..(j + 1) * 16], 1); + let e_limb = compose::(&e[j * 16..(j + 1) * 16], 1); + + // If it is a padding row or the last row of a message, the `hash` should be the + // [SHA_H] + builder + .when( + next.flags.is_padding_row() + + *next.flags.is_last_block * *next.flags.is_digest_row, + ) + .assert_eq( + a_limb, + 
AB::Expr::from_canonical_u32( + word_into_u16_limbs::(C::get_h()[C::ROUNDS_PER_ROW - i - 1])[j], + ), + ); + + builder + .when( + next.flags.is_padding_row() + + *next.flags.is_last_block * *next.flags.is_digest_row, + ) + .assert_eq( + e_limb, + AB::Expr::from_canonical_u32( + word_into_u16_limbs::(C::get_h()[C::ROUNDS_PER_ROW - i + 3])[j], + ), + ); + } + } + + // Check if last row of a non-last block, the `hash` should be equal to the final hash of + // the current block + for i in 0..C::ROUNDS_PER_ROW { + let prev_a = next.hash.a.row(i).mapv(|x| x.into()).to_vec(); + let prev_e = next.hash.e.row(i).mapv(|x| x.into()).to_vec(); + let cur_a = next + .final_hash + .row(C::ROUNDS_PER_ROW - i - 1) + .mapv(|x| x.into()); + + let cur_e = next + .final_hash + .row(C::ROUNDS_PER_ROW - i + 3) + .mapv(|x| x.into()); + for j in 0..C::WORD_U8S { + let prev_a_limb = compose::(&prev_a[j * 8..(j + 1) * 8], 1); + let prev_e_limb = compose::(&prev_e[j * 8..(j + 1) * 8], 1); + + builder + .when(not(*next.flags.is_last_block) * *next.flags.is_digest_row) + .assert_eq(prev_a_limb, cur_a[j].clone()); + + builder + .when(not(*next.flags.is_last_block) * *next.flags.is_digest_row) + .assert_eq(prev_e_limb, cur_e[j].clone()); + } + } + + // Assert that the previous hash + work vars == final hash. 
+ // That is, `next.prev_hash[i] + local.work_vars[i] == next.final_hash[i]` + // where addition is done modulo 2^32 + for i in 0..C::HASH_WORDS { + let mut carry = AB::Expr::ZERO; + for j in 0..C::WORD_U16S { + let work_var_limb = if i < C::ROUNDS_PER_ROW { + compose::( + local + .work_vars + .a + .slice(s![C::ROUNDS_PER_ROW - 1 - i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) + } else { + compose::( + local + .work_vars + .e + .slice(s![C::ROUNDS_PER_ROW + 3 - i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) + }; + let final_hash_limb = compose::( + next.final_hash + .slice(s![i, j * 2..(j + 1) * 2]) + .as_slice() + .unwrap(), + 8, + ); + + carry = AB::Expr::from(AB::F::from_canonical_u32(1 << 16).inverse()) + * (next.prev_hash[[i, j]] + work_var_limb + carry - final_hash_limb); + builder + .when(*next.flags.is_digest_row) + .assert_bool(carry.clone()); + } + // constrain the final hash limbs two at a time since we can do two checks per + // interaction + for chunk in next.final_hash.row(i).as_slice().unwrap().chunks(2) { + self.bitwise_lookup_bus + .send_range(chunk[0], chunk[1]) + .eval(builder, *next.flags.is_digest_row); + } + } + } + + fn eval_transitions(&self, builder: &mut AB, start_col: usize) { + let main = builder.main(); + let local = main.row_slice(0); + let next = main.row_slice(1); + + // Doesn't matter what column structs we use here + let local_cols: Sha2RoundColsRef = + Sha2RoundColsRef::from::(&local[start_col..start_col + C::SUBAIR_ROUND_WIDTH]); + let next_cols: Sha2RoundColsRef = + Sha2RoundColsRef::from::(&next[start_col..start_col + C::SUBAIR_ROUND_WIDTH]); + + let local_is_padding_row = local_cols.flags.is_padding_row(); + // Note that there will always be a padding row in the trace since the unpadded height is a + // multiple of 17 (SHA-256) or 21 (SHA-512, SHA-384). So the next row is padding iff the + // current block is the last block in the trace. 
+ let next_is_padding_row = next_cols.flags.is_padding_row(); + + // We check that the very last block has `is_last_block` set to true, which guarantees that + // there is at least one complete message. If other digest rows have `is_last_block` set to + // true, then the trace will be interpreted as containing multiple messages. + builder + .when(next_is_padding_row.clone()) + .when(*next_cols.flags.is_digest_row) + .assert_one(*next_cols.flags.is_last_block); + // If we are in a round row, the next row cannot be a padding row + builder + .when(*local_cols.flags.is_round_row) + .assert_zero(next_is_padding_row.clone()); + // The first row must be a round row + builder + .when_first_row() + .assert_one(*local_cols.flags.is_round_row); + // If we are in a padding row, the next row must also be a padding row + builder + .when_transition() + .when(local_is_padding_row.clone()) + .assert_one(next_is_padding_row.clone()); + // If we are in a digest row, the next row cannot be a digest row + builder + .when(*local_cols.flags.is_digest_row) + .assert_zero(*next_cols.flags.is_digest_row); + // Constrain how much the row index changes by + // round->round: 1 + // round->digest: 1 + // digest->round: -C::ROUND_ROWS + // digest->padding: 1 + // padding->padding: 0 + // Other transitions are not allowed by the above constraints + let delta = *local_cols.flags.is_round_row * AB::Expr::ONE + + *local_cols.flags.is_digest_row + * *next_cols.flags.is_round_row + * AB::Expr::from_canonical_usize(C::ROUND_ROWS) + * AB::Expr::NEG_ONE + + *local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE; + + let local_row_idx = self.row_idx_encoder.flag_with_val::( + local_cols.flags.row_idx.to_slice().unwrap(), + &(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::>(), + ); + let next_row_idx = self.row_idx_encoder.flag_with_val::( + next_cols.flags.row_idx.to_slice().unwrap(), + &(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::>(), + ); + + builder + .when_transition() + 
.assert_eq(local_row_idx.clone() + delta, next_row_idx.clone()); + builder.when_first_row().assert_zero(local_row_idx); + + // Constrain the global block index + // We set the global block index to 0 for padding rows + // Starting with 1 so it is not the same as the padding rows + + // Global block index is 1 on first row + builder + .when_first_row() + .assert_one(*local_cols.flags.global_block_idx); + + // Global block index is constant on all rows in a block + builder.when(*local_cols.flags.is_round_row).assert_eq( + *local_cols.flags.global_block_idx, + *next_cols.flags.global_block_idx, + ); + // Global block index increases by 1 between blocks + builder + .when_transition() + .when(*local_cols.flags.is_digest_row) + .when(*next_cols.flags.is_round_row) + .assert_eq( + *local_cols.flags.global_block_idx + AB::Expr::ONE, + *next_cols.flags.global_block_idx, + ); + // Global block index is 0 on padding rows + builder + .when(local_is_padding_row.clone()) + .assert_zero(*local_cols.flags.global_block_idx); + + // Constrain the local block index + // We set the local block index to 0 for padding rows + + // Local block index is constant on all rows in a block + // and its value on padding rows is equal to its value on the first block + builder + .when(not(*local_cols.flags.is_digest_row)) + .assert_eq( + *local_cols.flags.local_block_idx, + *next_cols.flags.local_block_idx, + ); + // Local block index increases by 1 between blocks in the same message + builder + .when(*local_cols.flags.is_digest_row) + .when(not(*local_cols.flags.is_last_block)) + .assert_eq( + *local_cols.flags.local_block_idx + AB::Expr::ONE, + *next_cols.flags.local_block_idx, + ); + // Local block index is 0 on padding rows + // Combined with the above, this means that the local block index is 0 in the first block + builder + .when(*local_cols.flags.is_digest_row) + .when(*local_cols.flags.is_last_block) + .assert_zero(*next_cols.flags.local_block_idx); + + self.eval_message_schedule(builder, 
local_cols.clone(), next_cols.clone()); + self.eval_work_vars(builder, local_cols.clone(), next_cols); + let next_cols: Sha2DigestColsRef = + Sha2DigestColsRef::from::(&next[start_col..start_col + C::SUBAIR_DIGEST_WIDTH]); + self.eval_digest_row(builder, local_cols, next_cols); + } + + /// Constrain the message schedule additions for `next` row + /// Note: For every addition we need to constrain the following for each of [WORD_U16S] limbs + /// sig_1(w_{t-2})[i] + w_{t-7}[i] + sig_0(w_{t-15})[i] + w_{t-16}[i] + carry_w[t][i-1] - + /// carry_w[t][i] * 2^16 - w_t[i] == 0 Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] + fn eval_message_schedule<'a, AB: InteractionBuilder>( + &self, + builder: &mut AB, + local: Sha2RoundColsRef<'a, AB::Var>, + next: Sha2RoundColsRef<'a, AB::Var>, + ) { + // This `w` array contains 8 message schedule words - w_{idx}, ..., w_{idx+7} for some idx + let w = ndarray::concatenate( + ndarray::Axis(0), + &[local.message_schedule.w, next.message_schedule.w], + ) + .unwrap(); + + // Constrain `w_3` for `next` row + for i in 0..C::ROUNDS_PER_ROW - 1 { + // here we constrain the w_3 of the i_th word of the next row + // w_3 of next is w[i+4-3] = w[i+1] + let w_3 = w.row(i + 1).mapv(|x| x.into()).to_vec(); + let expected_w_3 = next.schedule_helper.w_3.row(i); + for j in 0..C::WORD_U16S { + let w_3_limb = compose::(&w_3[j * 16..(j + 1) * 16], 1); + builder + .when(*local.flags.is_round_row) + .assert_eq(w_3_limb, expected_w_3[j].into()); + } + } + + // Constrain intermed for `next` row + // We will only constrain intermed_12 for rows [3, C::ROUND_ROWS - 2], and let it + // unconstrained for other rows Other rows should put the needed value in + // intermed_12 to make the below summation constraint hold + let is_row_intermed_12 = self.row_idx_encoder.contains_flag_range::( + next.flags.row_idx.to_slice().unwrap(), + 3..=C::ROUND_ROWS - 2, + ); + // We will only constrain intermed_8 for rows [2, C::ROUND_ROWS - 3], and let it + 
// unconstrained for other rows + let is_row_intermed_8 = self.row_idx_encoder.contains_flag_range::( + next.flags.row_idx.to_slice().unwrap(), + 2..=C::ROUND_ROWS - 3, + ); + for i in 0..C::ROUNDS_PER_ROW { + // w_idx + let w_idx = w.row(i).mapv(|x| x.into()).to_vec(); + // sig_0(w_{idx+1}) + let sig_w = small_sig0_field::(w.row(i + 1).as_slice().unwrap()); + for j in 0..C::WORD_U16S { + let w_idx_limb = compose::(&w_idx[j * 16..(j + 1) * 16], 1); + let sig_w_limb = compose::(&sig_w[j * 16..(j + 1) * 16], 1); + + // We would like to constrain this only on round rows, but we can't do a conditional + // check because the degree is already 3. So we must fill in `intermed_4` with dummy + // values on the first round row and the digest row (rows 0 and 16 for SHA-256) to + // ensure the constraint holds on these rows. + builder.when_transition().assert_eq( + next.schedule_helper.intermed_4[[i, j]], + w_idx_limb + sig_w_limb, + ); + + builder.when(is_row_intermed_8.clone()).assert_eq( + next.schedule_helper.intermed_8[[i, j]], + local.schedule_helper.intermed_4[[i, j]], + ); + + builder.when(is_row_intermed_12.clone()).assert_eq( + next.schedule_helper.intermed_12[[i, j]], + local.schedule_helper.intermed_8[[i, j]], + ); + } + } + + // Constrain the message schedule additions for `next` row + for i in 0..C::ROUNDS_PER_ROW { + // Note, here by w_{t} we mean the i_th word of the `next` row + // w_{t-7} + let w_7 = if i < 3 { + local.schedule_helper.w_3.row(i).mapv(|x| x.into()).to_vec() + } else { + let w_3 = w.row(i - 3).mapv(|x| x.into()).to_vec(); + (0..C::WORD_U16S) + .map(|j| compose::(&w_3[j * 16..(j + 1) * 16], 1)) + .collect::>() + }; + // sig_0(w_{t-15}) + w_{t-16} + let intermed_16 = local.schedule_helper.intermed_12.row(i).mapv(|x| x.into()); + + let carries = (0..C::WORD_U16S) + .map(|j| { + next.message_schedule.carry_or_buffer[[i, j * 2]] + + AB::Expr::TWO * next.message_schedule.carry_or_buffer[[i, j * 2 + 1]] + }) + .collect::>(); + + // Constrain `W_{idx} 
= sig_1(W_{idx-2}) + W_{idx-7} + sig_0(W_{idx-15}) + W_{idx-16}` + // We would like to constrain this only on rows 4..C::ROUND_ROWS, but we can't do a + // conditional check because the degree of sum is already 3 So we must fill + // in `intermed_12` with dummy values on rows 0..3 and C::ROUND_ROWS-1 and C::ROUND_ROWS + // to ensure the constraint holds on rows 0..4 and C::ROUND_ROWS. Note that + // the dummy value goes in the previous row to make the current row's constraint hold. + constraint_word_addition::<_, C>( + // Note: here we can't do a conditional check because the degree of sum is already + // 3 + &mut builder.when_transition(), + &[&small_sig1_field::( + w.row(i + 2).as_slice().unwrap(), + )], + &[&w_7, intermed_16.as_slice().unwrap()], + w.row(i + 4).as_slice().unwrap(), + &carries, + ); + + for j in 0..C::WORD_U16S { + // When on rows 4..C::ROUND_ROWS message schedule carries should be 0 or 1 + let is_row_4_or_more = *next.flags.is_round_row - *next.flags.is_first_4_rows; + builder + .when(is_row_4_or_more.clone()) + .assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2]]); + builder + .when(is_row_4_or_more) + .assert_bool(next.message_schedule.carry_or_buffer[[i, j * 2 + 1]]); + } + // Constrain w being composed of bits + for j in 0..C::WORD_BITS { + builder + .when(*next.flags.is_round_row) + .assert_bool(next.message_schedule.w[[i, j]]); + } + } + } + + /// Constrain the work vars on `next` row according to the sha documentation + /// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] + fn eval_work_vars<'a, AB: InteractionBuilder>( + &self, + builder: &mut AB, + local: Sha2RoundColsRef<'a, AB::Var>, + next: Sha2RoundColsRef<'a, AB::Var>, + ) { + let a = + ndarray::concatenate(ndarray::Axis(0), &[local.work_vars.a, next.work_vars.a]).unwrap(); + let e = + ndarray::concatenate(ndarray::Axis(0), &[local.work_vars.e, next.work_vars.e]).unwrap(); + + for i in 0..C::ROUNDS_PER_ROW { + for j in 0..C::WORD_U16S { + // Although 
we need carry_a <= 6 and carry_e <= 5, constraining carry_a, carry_e in + // [0, 2^8) is enough to prevent overflow and ensure the soundness + // of the addition we want to check + self.bitwise_lookup_bus + .send_range( + local.work_vars.carry_a[[i, j]], + local.work_vars.carry_e[[i, j]], + ) + .eval(builder, *local.flags.is_round_row); + } + + let w_limbs = (0..C::WORD_U16S) + .map(|j| { + compose::( + next.message_schedule + .w + .slice(s![i, j * 16..(j + 1) * 16]) + .as_slice() + .unwrap(), + 1, + ) * *next.flags.is_round_row + }) + .collect::>(); + + let k_limbs = (0..C::WORD_U16S) + .map(|j| { + self.row_idx_encoder.flag_with_val::( + next.flags.row_idx.to_slice().unwrap(), + &(0..C::ROUND_ROWS) + .map(|rw_idx| { + ( + rw_idx, + word_into_u16_limbs::( + C::get_k()[rw_idx * C::ROUNDS_PER_ROW + i], + )[j] as usize, + ) + }) + .collect::>(), + ) + }) + .collect::>(); + + let row_idx = self.row_idx_encoder.flag_with_val::( + next.flags.row_idx.to_slice().unwrap(), + &(0..=C::ROWS_PER_BLOCK).map(|i| (i, i)).collect::>(), + ); + let row_idx = format!("{:?}", row_idx); + if row_idx == "0" { + println!("row {} w_limbs: {:?}", row_idx, w_limbs); + println!("row {} k_limbs: {:?}", row_idx, k_limbs); + } + + // Constrain `a = h + sig_1(e) + ch(e, f, g) + K + W + sig_0(a) + Maj(a, b, c)` + // We have to enforce this constraint on all rows since the degree of the constraint is + // already 3. So, we must fill in `carry_a` with dummy values on digest rows + // to ensure the constraint holds. 
+ constraint_word_addition::<_, C>( + builder, + &[ + e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h` + &big_sig1_field::(e.row(i + 3).as_slice().unwrap()), /* sig_1 of + previous `e` */ + &ch_field::( + e.row(i + 3).as_slice().unwrap(), + e.row(i + 2).as_slice().unwrap(), + e.row(i + 1).as_slice().unwrap(), + ), /* Ch of previous `e`, `f`, `g` */ + &big_sig0_field::(a.row(i + 3).as_slice().unwrap()), /* sig_0 of + previous `a` */ + &maj_field::( + a.row(i + 3).as_slice().unwrap(), + a.row(i + 2).as_slice().unwrap(), + a.row(i + 1).as_slice().unwrap(), + ), /* Maj of previous a, b, c */ + ], + &[&w_limbs, &k_limbs], // K and W + a.row(i + 4).as_slice().unwrap(), // new `a` + next.work_vars.carry_a.row(i).as_slice().unwrap(), // carries of addition + ); + + // Constrain `e = d + h + sig_1(e) + ch(e, f, g) + K + W` + // We have to enforce this constraint on all rows since the degree of the constraint is + // already 3. So, we must fill in `carry_e` with dummy values on digest rows + // to ensure the constraint holds. + constraint_word_addition::<_, C>( + builder, + &[ + a.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `d` + e.row(i).mapv(|x| x.into()).as_slice().unwrap(), // previous `h` + &big_sig1_field::(e.row(i + 3).as_slice().unwrap()), /* sig_1 of + previous `e` */ + &ch_field::( + e.row(i + 3).as_slice().unwrap(), + e.row(i + 2).as_slice().unwrap(), + e.row(i + 1).as_slice().unwrap(), + ), /* Ch of previous `e`, `f`, `g` */ + ], + &[&w_limbs, &k_limbs], // K and W + e.row(i + 4).as_slice().unwrap(), // new `e` + next.work_vars.carry_e.row(i).as_slice().unwrap(), // carries of addition + ); + } + } +} diff --git a/crates/circuits/sha2-air/src/columns.rs b/crates/circuits/sha2-air/src/columns.rs new file mode 100644 index 0000000000..bae9e752b2 --- /dev/null +++ b/crates/circuits/sha2-air/src/columns.rs @@ -0,0 +1,195 @@ +//! 
WARNING: the order of fields in the structs is important, do not change it + +use openvm_circuit_primitives::utils::not; +use openvm_circuit_primitives_derive::ColsRef; +use openvm_stark_backend::p3_field::FieldAlgebra; + +use crate::Sha2BlockHasherSubairConfig; + +/// In each SHA block: +/// - First C::ROUND_ROWS rows use Sha2RoundCols +/// - Final row uses Sha2DigestCols +/// +/// Note that for soundness, we require that there is always a padding row after the last digest row +/// in the trace. Right now, this is true because the unpadded height is a multiple of 17 (SHA-256) +/// or 21 (SHA-512), and thus not a power of 2. +/// +/// Sha2RoundCols and Sha2DigestCols share the same first 3 fields: +/// - flags +/// - work_vars/hash (same type, different name) +/// - schedule_helper +/// +/// This design allows for: +/// 1. Common constraints to work on either struct type by accessing these shared fields +/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional +/// constraints +/// +/// Note that the `Sha2WorkVarsCols` field is used for different purposes in the two structs. 
+#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2RoundCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub flags: Sha2FlagsCols, + pub work_vars: Sha2WorkVarsCols, + pub schedule_helper: + Sha2MessageHelperCols, + pub message_schedule: Sha2MessageScheduleCols, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2DigestCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const HASH_WORDS: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub flags: Sha2FlagsCols, + /// Will serve as previous hash values for the next block + pub hash: Sha2WorkVarsCols, + pub schedule_helper: + Sha2MessageHelperCols, + /// The actual final hash values of the given block + /// Note: the above `hash` will be equal to `final_hash` unless we are on the last block + pub final_hash: [[T; WORD_U8S]; HASH_WORDS], + /// The final hash of the previous block + /// Note: will be constrained using interactions with the chip itself + pub prev_hash: [[T; WORD_U16S]; HASH_WORDS], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2MessageScheduleCols< + T, + const WORD_BITS: usize, + const ROUNDS_PER_ROW: usize, + const WORD_U8S: usize, +> { + /// The message schedule words as bits + /// The first 16 words will be the message data + pub w: [[T; WORD_BITS]; ROUNDS_PER_ROW], + /// Will be message schedule carries for rows 4..C::ROUND_ROWS and a buffer for rows 0..4 to be + /// used freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells + /// as individual bits + /// Note: carry_or_buffer is left unconstrained on rounds 0..3 + pub carry_or_buffer: [[T; 
WORD_U8S]; ROUNDS_PER_ROW], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2WorkVarsCols< + T, + const WORD_BITS: usize, + const ROUNDS_PER_ROW: usize, + const WORD_U16S: usize, +> { + /// `a` and `e` after each iteration as 32-bits + pub a: [[T; WORD_BITS]; ROUNDS_PER_ROW], + pub e: [[T; WORD_BITS]; ROUNDS_PER_ROW], + /// The carry's used for addition during each iteration when computing `a` and `e` + pub carry_a: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub carry_e: [[T; WORD_U16S]; ROUNDS_PER_ROW], +} + +/// These are the columns that are used to help with the message schedule additions +/// Note: these need to be correctly assigned for every row even on padding rows +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2MessageHelperCols< + T, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, +> { + /// The following are used to move data forward to constrain the message schedule additions + /// The value of `w` from 3 rounds ago + pub w_3: [[T; WORD_U16S]; ROUNDS_PER_ROW_MINUS_ONE], + /// Here intermediate(i) = w_i + sig_0(w_{i+1}) + /// Intermed_t represents the intermediate t rounds ago + /// This is needed to constrain the message schedule, since we can only constrain on two rows + /// at a time + pub intermed_4: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub intermed_8: [[T; WORD_U16S]; ROUNDS_PER_ROW], + pub intermed_12: [[T; WORD_U16S]; ROUNDS_PER_ROW], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2FlagsCols { + pub is_round_row: T, + /// A flag that indicates if the current row is among the first 4 rows of a block (the message + /// rows) + pub is_first_4_rows: T, + pub is_digest_row: T, + // This flag will always be true. This flag exists because this air was made to support + // multi-block hashes. 
We chose to leave this flag here since we may support multi-block + // hashes in the future. + pub is_last_block: T, + /// We will encode the row index [0..C::ROWS_PER_BLOCK] using ROW_VAR_CNT cells + pub row_idx: [T; ROW_VAR_CNT], + /// The global index of the current block + pub global_block_idx: T, + /// Will store the index of the current block in the current message starting from 0 + pub local_block_idx: T, +} + +// impl, const ROW_VAR_CNT: usize> +// Sha2FlagsCols +// { +// // This refers to the padding rows that are added to the air to make the trace length a power +// of // 2. Not to be confused with the padding added to messages as part of the SHA hash +// // function. +// pub fn is_not_padding_row(&self) -> O { +// self.is_round_row + self.is_digest_row +// } + +// // This refers to the padding rows that are added to the air to make the trace length a power +// of // 2. Not to be confused with the padding added to messages as part of the SHA hash +// // function. +// pub fn is_padding_row(&self) -> O +// where +// O: FieldAlgebra, +// { +// not(self.is_not_padding_row()) +// } +// } + +impl> Sha2FlagsColsRef<'_, T> { + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. + pub fn is_not_padding_row(&self) -> O { + *self.is_round_row + *self.is_digest_row + } + + // This refers to the padding rows that are added to the air to make the trace length a power of + // 2. Not to be confused with the padding added to messages as part of the SHA hash + // function. 
+ pub fn is_padding_row(&self) -> O + where + O: FieldAlgebra, + { + not(self.is_not_padding_row()) + } + + pub fn is_enabled(&self) -> O { + *self.is_round_row + *self.is_digest_row + } +} diff --git a/crates/circuits/sha2-air/src/config.rs b/crates/circuits/sha2-air/src/config.rs new file mode 100644 index 0000000000..4e5d4d5df2 --- /dev/null +++ b/crates/circuits/sha2-air/src/config.rs @@ -0,0 +1,391 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not, Shl, Shr}; + +use crate::{Sha2DigestColsRef, Sha2RoundColsRef}; + +#[repr(u32)] +#[derive(num_enum::TryFromPrimitive, num_enum::IntoPrimitive, Copy, Clone, Debug)] +pub enum Sha2Variant { + Sha256, + Sha512, + Sha384, +} + +pub trait Sha2BlockHasherSubairConfig: Send + Sync + Clone { + // --- Required --- + + type Word: 'static + + Shr + + Shl + + BitAnd + + Not + + BitXor + + BitOr + + RotateRight + + WrappingAdd + + PartialEq + + From + + TryInto + + Copy + + Send + + Sync; + // Differentiate between the SHA-2 variants + const VARIANT: Sha2Variant; + /// Number of bits in a SHA word + const WORD_BITS: usize; + /// Number of words in a SHA block + const BLOCK_WORDS: usize; + /// Number of rows per block + const ROWS_PER_BLOCK: usize; + /// Number of rounds per row. Must divide Self::ROUNDS_PER_BLOCK + const ROUNDS_PER_ROW: usize; + /// Number of rounds per block. Must be a multiple of Self::ROUNDS_PER_ROW + const ROUNDS_PER_BLOCK: usize; + /// Number of words in a SHA hash + const HASH_WORDS: usize; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize; + + /// To optimize the trace generation of invalid rows, we precompute those values. 
+ /// these should be appropriately sized for the config + fn get_invalid_carry_a(round_num: usize) -> &'static [u32]; + fn get_invalid_carry_e(round_num: usize) -> &'static [u32]; + + /// We also store the SHA constants K and H + fn get_k() -> &'static [Self::Word]; + fn get_h() -> &'static [Self::Word]; + + // --- Provided --- + + /// Number of 16-bit limbs in a SHA word + const WORD_U16S: usize = Self::WORD_BITS / 16; + /// Number of 8-bit limbs in a SHA word + const WORD_U8S: usize = Self::WORD_BITS / 8; + /// Number of cells in a SHA block + const BLOCK_U8S: usize = Self::BLOCK_WORDS * Self::WORD_U8S; + /// Number of bits in a SHA block + const BLOCK_BITS: usize = Self::BLOCK_WORDS * Self::WORD_BITS; + /// Number of rows used for the sha rounds + const ROUND_ROWS: usize = Self::ROUNDS_PER_BLOCK / Self::ROUNDS_PER_ROW; + /// Number of rows used for the message + const MESSAGE_ROWS: usize = Self::BLOCK_WORDS / Self::ROUNDS_PER_ROW; + /// Number of rounds per row minus one (needed for one of the column structs) + const ROUNDS_PER_ROW_MINUS_ONE: usize = Self::ROUNDS_PER_ROW - 1; + /// Width of the Sha2RoundCols + const SUBAIR_ROUND_WIDTH: usize = Sha2RoundColsRef::::width::(); + /// Width of the Sha2DigestCols + const SUBAIR_DIGEST_WIDTH: usize = Sha2DigestColsRef::::width::(); + /// Width of the Sha2BlockHasherCols + const SUBAIR_WIDTH: usize = if Self::SUBAIR_ROUND_WIDTH > Self::SUBAIR_DIGEST_WIDTH { + Self::SUBAIR_ROUND_WIDTH + } else { + Self::SUBAIR_DIGEST_WIDTH + }; +} + +#[derive(Clone)] +pub struct Sha256Config; + +#[derive(Clone)] +pub struct Sha512Config; + +#[derive(Clone)] +pub struct Sha384Config; + +impl Sha2BlockHasherSubairConfig for Sha256Config { + // ==== Do not change these constants! 
==== + const VARIANT: Sha2Variant = Sha2Variant::Sha256; + type Word = u32; + /// Number of bits in a SHA256 word + const WORD_BITS: usize = 32; + /// Number of words in a SHA256 block + const BLOCK_WORDS: usize = 16; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = 17; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = 4; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = 64; + /// Number of words in a SHA256 hash + const HASH_WORDS: usize = 8; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = 5; + + fn get_invalid_carry_a(round_num: usize) -> &'static [u32] { + &SHA256_INVALID_CARRY_A[round_num] + } + fn get_invalid_carry_e(round_num: usize) -> &'static [u32] { + &SHA256_INVALID_CARRY_E[round_num] + } + fn get_k() -> &'static [u32] { + &SHA256_K + } + fn get_h() -> &'static [u32] { + &SHA256_H + } +} + +pub const SHA256_INVALID_CARRY_A: [[u32; Sha256Config::WORD_U16S]; Sha256Config::ROUNDS_PER_ROW] = [ + [1230919683, 1162494304], + [266373122, 1282901987], + [1519718403, 1008990871], + [923381762, 330807052], +]; +pub const SHA256_INVALID_CARRY_E: [[u32; Sha256Config::WORD_U16S]; Sha256Config::ROUNDS_PER_ROW] = [ + [204933122, 1994683449], + [443873282, 1544639095], + [719953922, 1888246508], + [194580482, 1075725211], +]; + +/// SHA256 constant K's +pub const SHA256_K: [u32; 64] = [ + 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, + 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, + 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, + 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, + 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, + 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, + 0x19a4c116, 
0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, + 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, +]; +/// SHA256 initial hash values +pub const SHA256_H: [u32; 8] = [ + 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, +]; + +impl Sha2BlockHasherSubairConfig for Sha512Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha512; + type Word = u64; + /// Number of bits in a SHA512 word + const WORD_BITS: usize = 64; + /// Number of words in a SHA512 block + const BLOCK_WORDS: usize = 16; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = 21; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = 4; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = 80; + /// Number of words in a SHA512 hash + const HASH_WORDS: usize = 8; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = 6; + + fn get_invalid_carry_a(round_num: usize) -> &'static [u32] { + &SHA512_INVALID_CARRY_A[round_num] + } + fn get_invalid_carry_e(round_num: usize) -> &'static [u32] { + &SHA512_INVALID_CARRY_E[round_num] + } + fn get_k() -> &'static [u64] { + &SHA512_K + } + fn get_h() -> &'static [u64] { + &SHA512_H + } +} + +pub(crate) const SHA512_INVALID_CARRY_A: [[u32; Sha512Config::WORD_U16S]; + Sha512Config::ROUNDS_PER_ROW] = [ + [55971842, 827997017, 993005918, 512731953], + [227512322, 1697529235, 1936430385, 940122990], + [1939875843, 1173318562, 826201586, 1513494849], + [891955202, 1732283693, 1736658755, 223514501], +]; + +pub(crate) const SHA512_INVALID_CARRY_E: [[u32; Sha512Config::WORD_U16S]; + Sha512Config::ROUNDS_PER_ROW] = [ + [1384427522, 1509509767, 153131516, 102514978], + [1527552003, 1041677071, 837289497, 843522538], + [775188482, 1620184630, 744892564, 892058728], + [1801267202, 1393118048, 1846108940, 830635531], +]; + +/// SHA512 
constant K's +pub const SHA512_K: [u64; 80] = [ + 0x428a2f98d728ae22, + 0x7137449123ef65cd, + 0xb5c0fbcfec4d3b2f, + 0xe9b5dba58189dbbc, + 0x3956c25bf348b538, + 0x59f111f1b605d019, + 0x923f82a4af194f9b, + 0xab1c5ed5da6d8118, + 0xd807aa98a3030242, + 0x12835b0145706fbe, + 0x243185be4ee4b28c, + 0x550c7dc3d5ffb4e2, + 0x72be5d74f27b896f, + 0x80deb1fe3b1696b1, + 0x9bdc06a725c71235, + 0xc19bf174cf692694, + 0xe49b69c19ef14ad2, + 0xefbe4786384f25e3, + 0x0fc19dc68b8cd5b5, + 0x240ca1cc77ac9c65, + 0x2de92c6f592b0275, + 0x4a7484aa6ea6e483, + 0x5cb0a9dcbd41fbd4, + 0x76f988da831153b5, + 0x983e5152ee66dfab, + 0xa831c66d2db43210, + 0xb00327c898fb213f, + 0xbf597fc7beef0ee4, + 0xc6e00bf33da88fc2, + 0xd5a79147930aa725, + 0x06ca6351e003826f, + 0x142929670a0e6e70, + 0x27b70a8546d22ffc, + 0x2e1b21385c26c926, + 0x4d2c6dfc5ac42aed, + 0x53380d139d95b3df, + 0x650a73548baf63de, + 0x766a0abb3c77b2a8, + 0x81c2c92e47edaee6, + 0x92722c851482353b, + 0xa2bfe8a14cf10364, + 0xa81a664bbc423001, + 0xc24b8b70d0f89791, + 0xc76c51a30654be30, + 0xd192e819d6ef5218, + 0xd69906245565a910, + 0xf40e35855771202a, + 0x106aa07032bbd1b8, + 0x19a4c116b8d2d0c8, + 0x1e376c085141ab53, + 0x2748774cdf8eeb99, + 0x34b0bcb5e19b48a8, + 0x391c0cb3c5c95a63, + 0x4ed8aa4ae3418acb, + 0x5b9cca4f7763e373, + 0x682e6ff3d6b2b8a3, + 0x748f82ee5defb2fc, + 0x78a5636f43172f60, + 0x84c87814a1f0ab72, + 0x8cc702081a6439ec, + 0x90befffa23631e28, + 0xa4506cebde82bde9, + 0xbef9a3f7b2c67915, + 0xc67178f2e372532b, + 0xca273eceea26619c, + 0xd186b8c721c0c207, + 0xeada7dd6cde0eb1e, + 0xf57d4f7fee6ed178, + 0x06f067aa72176fba, + 0x0a637dc5a2c898a6, + 0x113f9804bef90dae, + 0x1b710b35131c471b, + 0x28db77f523047d84, + 0x32caab7b40c72493, + 0x3c9ebe0a15c9bebc, + 0x431d67c49c100d4c, + 0x4cc5d4becb3e42b6, + 0x597f299cfc657e2a, + 0x5fcb6fab3ad6faec, + 0x6c44198c4a475817, +]; +/// SHA512 initial hash values +pub const SHA512_H: [u64; 8] = [ + 0x6a09e667f3bcc908, + 0xbb67ae8584caa73b, + 0x3c6ef372fe94f82b, + 0xa54ff53a5f1d36f1, + 0x510e527fade682d1, + 
0x9b05688c2b3e6c1f, + 0x1f83d9abfb41bd6b, + 0x5be0cd19137e2179, +]; + +impl Sha2BlockHasherSubairConfig for Sha384Config { + // ==== Do not change these constants! ==== + const VARIANT: Sha2Variant = Sha2Variant::Sha384; + type Word = ::Word; + /// Number of bits in a SHA384 word + const WORD_BITS: usize = ::WORD_BITS; + /// Number of words in a SHA384 block + const BLOCK_WORDS: usize = ::BLOCK_WORDS; + /// Number of rows per block + const ROWS_PER_BLOCK: usize = ::ROWS_PER_BLOCK; + /// Number of rounds per row + const ROUNDS_PER_ROW: usize = ::ROUNDS_PER_ROW; + /// Number of rounds per block + const ROUNDS_PER_BLOCK: usize = ::ROUNDS_PER_BLOCK; + /// Number of words in a SHA384 hash + const HASH_WORDS: usize = ::HASH_WORDS; + /// Number of vars needed to encode the row index with [Encoder] + const ROW_VAR_CNT: usize = ::ROW_VAR_CNT; + + fn get_invalid_carry_a(round_num: usize) -> &'static [u32] { + &SHA384_INVALID_CARRY_A[round_num] + } + fn get_invalid_carry_e(round_num: usize) -> &'static [u32] { + &SHA384_INVALID_CARRY_E[round_num] + } + fn get_k() -> &'static [u64] { + &SHA384_K + } + fn get_h() -> &'static [u64] { + &SHA384_H + } +} + +pub(crate) const SHA384_INVALID_CARRY_A: [[u32; Sha384Config::WORD_U16S]; + Sha384Config::ROUNDS_PER_ROW] = [ + [1571481603, 1428841901, 1050676523, 793575075], + [1233315842, 1822329223, 112923808, 1874228927], + [1245603842, 927240770, 1579759431, 70557227], + [195532801, 594312107, 1429379950, 220407092], +]; + +pub(crate) const SHA384_INVALID_CARRY_E: [[u32; Sha384Config::WORD_U16S]; + Sha384Config::ROUNDS_PER_ROW] = [ + [1067980802, 1508061099, 1418826213, 1232569491], + [1453086722, 1702524575, 152427899, 238512408], + [1623674882, 701393097, 1002035664, 4776891], + [1888911362, 184963225, 1151849224, 1034237098], +]; + +/// SHA384 constant K's +pub const SHA384_K: [u64; 80] = SHA512_K; + +/// SHA384 initial hash values +pub const SHA384_H: [u64; 8] = [ + 0xcbbb9d5dc1059ed8, + 0x629a292a367cd507, + 0x9159015a3070dd17, + 
0x152fecd8f70e5939, + 0x67332667ffc00b31, + 0x8eb44a8768581511, + 0xdb0c2e0d64f98fa7, + 0x47b5481dbefa4fa4, +]; + +// Needed to avoid compile errors in utils.rs +// not sure why this doesn't inf loop +pub trait RotateRight { + fn rotate_right(self, n: u32) -> Self; +} +impl RotateRight for u32 { + fn rotate_right(self, n: u32) -> Self { + self.rotate_right(n) + } +} +impl RotateRight for u64 { + fn rotate_right(self, n: u32) -> Self { + self.rotate_right(n) + } +} +pub trait WrappingAdd { + fn wrapping_add(self, n: Self) -> Self; +} +impl WrappingAdd for u32 { + fn wrapping_add(self, n: u32) -> Self { + self.wrapping_add(n) + } +} +impl WrappingAdd for u64 { + fn wrapping_add(self, n: u64) -> Self { + self.wrapping_add(n) + } +} diff --git a/crates/circuits/sha256-air/src/lib.rs b/crates/circuits/sha2-air/src/lib.rs similarity index 51% rename from crates/circuits/sha256-air/src/lib.rs rename to crates/circuits/sha2-air/src/lib.rs index 48bdaee5f9..99e254281a 100644 --- a/crates/circuits/sha256-air/src/lib.rs +++ b/crates/circuits/sha2-air/src/lib.rs @@ -1,13 +1,12 @@ -//! Implementation of the SHA256 compression function without padding -//! 
This this AIR doesn't constrain any of the message padding - mod air; mod columns; +mod config; mod trace; mod utils; pub use air::*; pub use columns::*; +pub use config::*; pub use trace::*; pub use utils::*; diff --git a/crates/circuits/sha2-air/src/tests.rs b/crates/circuits/sha2-air/src/tests.rs new file mode 100644 index 0000000000..e69de29bb2 diff --git a/crates/circuits/sha2-air/src/trace.rs b/crates/circuits/sha2-air/src/trace.rs new file mode 100644 index 0000000000..73feef0477 --- /dev/null +++ b/crates/circuits/sha2-air/src/trace.rs @@ -0,0 +1,820 @@ +use std::{marker::PhantomData, mem, ops::Range}; + +use itertools::Itertools; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + encoder::Encoder, + utils::{compose, next_power_of_two_or_zero}, +}; +use openvm_stark_backend::{ + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::dense::RowMajorMatrix, + p3_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut}, +}; +use sha2::{compress256, compress512, digest::generic_array::GenericArray}; + +use crate::{ + big_sig0, big_sig0_field, big_sig1, big_sig1_field, ch, ch_field, get_flag_pt_array, + le_limbs_into_word, maj, maj_field, set_arrayview_from_u32_slice, small_sig0, small_sig0_field, + small_sig1, small_sig1_field, word_into_bits, word_into_u16_limbs, word_into_u8_limbs, + Sha2BlockHasherSubairConfig, Sha2DigestColsRefMut, Sha2RoundColsRef, Sha2RoundColsRefMut, + Sha2Variant, WrappingAdd, +}; + +/// A helper struct for the SHA-2 trace generation. +/// Also, separates the inner AIR from the trace generation. +pub struct Sha2BlockHasherFillerHelper { + pub row_idx_encoder: Encoder, + _phantom: PhantomData, +} + +impl Default for Sha2BlockHasherFillerHelper { + fn default() -> Self { + Self::new() + } +} + +/// The trace generation of SHA-2 should be done in two passes. 
+/// The first pass should do `get_block_trace` for every block and generate the invalid rows through +/// `get_default_row` The second pass should go through all the blocks and call +/// `generate_missing_cells` +impl Sha2BlockHasherFillerHelper { + pub fn new() -> Self { + Self { + row_idx_encoder: Encoder::new(C::ROWS_PER_BLOCK + 1, 2, false), + _phantom: PhantomData, + } + } + + /// This function takes the input_message (padding not handled), the previous hash, + /// and returns the new hash after processing the block input + pub fn get_block_hash(prev_hash: &[C::Word], input: Vec) -> Vec { + debug_assert!(prev_hash.len() == C::HASH_WORDS); + debug_assert!(input.len() == C::BLOCK_U8S); + let mut new_hash: [C::Word; 8] = prev_hash.try_into().unwrap(); + match C::VARIANT { + Sha2Variant::Sha256 => { + let input_array = [*GenericArray::::from_slice( + &input, + )]; + let hash_ptr: &mut [u32; 8] = unsafe { std::mem::transmute(&mut new_hash) }; + compress256(hash_ptr, &input_array); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let hash_ptr: &mut [u64; 8] = unsafe { std::mem::transmute(&mut new_hash) }; + let input_array = [*GenericArray::::from_slice( + &input, + )]; + compress512(hash_ptr, &input_array); + } + } + new_hash.to_vec() + } + + /// This function takes a C::BLOCK_BITS-bit chunk of the input message (padding not handled), + /// the previous hash, a flag indicating if it's the last block, the global block index, the + /// local block index, and the buffer values that will be put in rows 0..4. + /// Will populate the given `trace` with the trace of the block, where the width of the trace is + /// `trace_width` and the starting column for the `Sha2Air` is `trace_start_col`. + /// **Note**: this function only generates some of the required trace. Another pass is required, + /// refer to [`Self::generate_missing_cells`] for details. 
+ #[allow(clippy::too_many_arguments)] + pub fn generate_block_trace( + &self, + trace: &mut [F], + trace_width: usize, + trace_start_col: usize, + input: &[C::Word], + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + prev_hash: &[C::Word], + is_last_block: bool, + global_block_idx: u32, + ) { + #[cfg(debug_assertions)] + { + assert!(input.len() == C::BLOCK_WORDS); + assert!(prev_hash.len() == C::HASH_WORDS); + assert!(trace_start_col + C::SUBAIR_WIDTH == trace_width); + assert!(trace.len() == trace_width * C::ROWS_PER_BLOCK); + } + + let get_range = |start: usize, len: usize| -> Range { start..start + len }; + let mut message_schedule = vec![C::Word::from(0); C::ROUNDS_PER_BLOCK]; + message_schedule[..input.len()].copy_from_slice(input); + let mut work_vars = prev_hash.to_vec(); + for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() { + // do the rounds + if i < C::ROUND_ROWS { + let mut cols: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut row[get_range(trace_start_col, C::SUBAIR_ROUND_WIDTH)], + ); + *cols.flags.is_round_row = F::ONE; + *cols.flags.is_first_4_rows = if i < C::MESSAGE_ROWS { F::ONE } else { F::ZERO }; + *cols.flags.is_digest_row = F::ZERO; + *cols.flags.is_last_block = F::from_bool(is_last_block); + cols.flags + .row_idx + .iter_mut() + .zip( + get_flag_pt_array(&self.row_idx_encoder, i) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + + *cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); + *cols.flags.local_block_idx = F::from_canonical_u32(0); + + // W_idx = M_idx + if i < C::MESSAGE_ROWS { + for j in 0..C::ROUNDS_PER_ROW { + cols.message_schedule + .w + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(input[i * C::ROUNDS_PER_ROW + j]) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + } + // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16} + else { + for j in 0..C::ROUNDS_PER_ROW { + let idx = i * 
C::ROUNDS_PER_ROW + j; + let nums: [C::Word; 4] = [ + small_sig1::(message_schedule[idx - 2]), + message_schedule[idx - 7], + small_sig0::(message_schedule[idx - 15]), + message_schedule[idx - 16], + ]; + let w: C::Word = nums + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + cols.message_schedule + .w + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(w) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + + let nums_limbs = nums + .iter() + .map(|x| word_into_u16_limbs::(*x)) + .collect::>(); + let w_limbs = word_into_u16_limbs::(w); + + // fill in the carrys + for k in 0..C::WORD_U16S { + let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]); + if k > 0 { + sum += (cols.message_schedule.carry_or_buffer[[j, k * 2 - 2]] + + F::TWO + * cols.message_schedule.carry_or_buffer[[j, k * 2 - 1]]) + .as_canonical_u32(); + } + let carry = (sum - w_limbs[k]) >> 16; + cols.message_schedule.carry_or_buffer[[j, k * 2]] = + F::from_canonical_u32(carry & 1); + cols.message_schedule.carry_or_buffer[[j, k * 2 + 1]] = + F::from_canonical_u32(carry >> 1); + } + // update the message schedule + message_schedule[idx] = w; + } + } + // fill in the work variables + for j in 0..C::ROUNDS_PER_ROW { + // t1 = h + SIG1(e) + ch(e, f, g) + K_idx + W_idx + let t1 = [ + work_vars[7], + big_sig1::(work_vars[4]), + ch::(work_vars[4], work_vars[5], work_vars[6]), + C::get_k()[i * C::ROUNDS_PER_ROW + j], + le_limbs_into_word::( + cols.message_schedule + .w + .row(j) + .map(|f| f.as_canonical_u32()) + .as_slice() + .unwrap(), + ), + ]; + let t1_sum: C::Word = t1 + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + + // t2 = SIG0(a) + maj(a, b, c) + let t2 = [ + big_sig0::(work_vars[0]), + maj::(work_vars[0], work_vars[1], work_vars[2]), + ]; + + let t2_sum: C::Word = t2 + .iter() + .fold(C::Word::from(0), |acc, &num| acc.wrapping_add(num)); + + // e = d + t1 + let e = work_vars[3].wrapping_add(t1_sum); + 
cols.work_vars + .e + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(e) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + let e_limbs = word_into_u16_limbs::(e); + // a = t1 + t2 + let a = t1_sum.wrapping_add(t2_sum); + cols.work_vars + .a + .row_mut(j) + .iter_mut() + .zip( + word_into_bits::(a) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + let a_limbs = word_into_u16_limbs::(a); + // fill in the carrys + for k in 0..C::WORD_U16S { + let t1_limb = t1 + .iter() + .fold(0, |acc, &num| acc + word_into_u16_limbs::(num)[k]); + let t2_limb = t2 + .iter() + .fold(0, |acc, &num| acc + word_into_u16_limbs::(num)[k]); + + let mut e_limb = t1_limb + word_into_u16_limbs::(work_vars[3])[k]; + let mut a_limb = t1_limb + t2_limb; + if k > 0 { + a_limb += cols.work_vars.carry_a[[j, k - 1]].as_canonical_u32(); + e_limb += cols.work_vars.carry_e[[j, k - 1]].as_canonical_u32(); + } + let carry_a = (a_limb - a_limbs[k]) >> 16; + let carry_e = (e_limb - e_limbs[k]) >> 16; + cols.work_vars.carry_a[[j, k]] = F::from_canonical_u32(carry_a); + cols.work_vars.carry_e[[j, k]] = F::from_canonical_u32(carry_e); + bitwise_lookup_chip.request_range(carry_a, carry_e); + } + + // update working variables + work_vars[7] = work_vars[6]; + work_vars[6] = work_vars[5]; + work_vars[5] = work_vars[4]; + work_vars[4] = e; + work_vars[3] = work_vars[2]; + work_vars[2] = work_vars[1]; + work_vars[1] = work_vars[0]; + work_vars[0] = a; + } + + // filling w_3 and intermed_4 here and the rest later + if i > 0 { + for j in 0..C::ROUNDS_PER_ROW { + let idx = i * C::ROUNDS_PER_ROW + j; + let w_4 = word_into_u16_limbs::(message_schedule[idx - 4]); + let sig_0_w_3 = + word_into_u16_limbs::(small_sig0::(message_schedule[idx - 3])); + cols.schedule_helper + .intermed_4 + .row_mut(j) + .iter_mut() + .zip( + (0..C::WORD_U16S) + .map(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k])) + .collect::>(), + ) + .for_each(|(x, y)| *x = y); + if j < 
C::ROUNDS_PER_ROW - 1 { + let w_3 = message_schedule[idx - 3]; + cols.schedule_helper + .w_3 + .row_mut(j) + .iter_mut() + .zip( + word_into_u16_limbs::(w_3) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + } + } + } + // generate the digest row + else { + let mut cols: Sha2DigestColsRefMut = Sha2DigestColsRefMut::from::( + &mut row[get_range(trace_start_col, C::SUBAIR_DIGEST_WIDTH)], + ); + for j in 0..C::ROUNDS_PER_ROW - 1 { + let w_3 = message_schedule[i * C::ROUNDS_PER_ROW + j - 3]; + cols.schedule_helper + .w_3 + .row_mut(j) + .iter_mut() + .zip( + word_into_u16_limbs::(w_3) + .into_iter() + .map(F::from_canonical_u32) + .collect::>(), + ) + .for_each(|(x, y)| *x = y); + } + *cols.flags.is_round_row = F::ZERO; + *cols.flags.is_first_4_rows = F::ZERO; + *cols.flags.is_digest_row = F::ONE; + *cols.flags.is_last_block = F::from_bool(is_last_block); + cols.flags + .row_idx + .iter_mut() + .zip( + get_flag_pt_array(&self.row_idx_encoder, C::ROUND_ROWS) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + + *cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); + + *cols.flags.local_block_idx = F::from_canonical_u32(0); + let final_hash: Vec = (0..C::HASH_WORDS) + .map(|i| work_vars[i].wrapping_add(prev_hash[i])) + .collect(); + let final_hash_limbs: Vec> = final_hash + .iter() + .map(|word| word_into_u8_limbs::(*word)) + .collect(); + // need to ensure final hash limbs are bytes, in order for + // prev_hash[i] + work_vars[i] == final_hash[i] + // to be constrained correctly + for word in final_hash_limbs.iter() { + for chunk in word.chunks(2) { + bitwise_lookup_chip.request_range(chunk[0], chunk[1]); + } + } + cols.final_hash + .iter_mut() + .zip((0..C::HASH_WORDS).flat_map(|i| { + word_into_u8_limbs::(final_hash[i]) + .into_iter() + .map(F::from_canonical_u32) + .collect::>() + })) + .for_each(|(x, y)| *x = y); + cols.prev_hash + .iter_mut() + .zip(prev_hash.iter().flat_map(|f| { + 
word_into_u16_limbs::(*f) + .into_iter() + .map(F::from_canonical_u32) + .collect::>() + })) + .for_each(|(x, y)| *x = y); + + let hash = if is_last_block { + C::get_h() + .iter() + .map(|x| word_into_bits::(*x)) + .collect::>() + } else { + cols.final_hash + .rows_mut() + .into_iter() + .map(|f| { + le_limbs_into_word::( + f.map(|x| x.as_canonical_u32()).as_slice().unwrap(), + ) + }) + .map(word_into_bits::) + .collect() + } + .into_iter() + .map(|x| x.into_iter().map(F::from_canonical_u32)) + .collect::>(); + + for i in 0..C::ROUNDS_PER_ROW { + cols.hash + .a + .row_mut(i) + .iter_mut() + .zip(hash[C::ROUNDS_PER_ROW - i - 1].clone()) + .for_each(|(x, y)| *x = y); + cols.hash + .e + .row_mut(i) + .iter_mut() + .zip(hash[C::ROUNDS_PER_ROW - i + 3].clone()) + .for_each(|(x, y)| *x = y); + } + } + } + + for i in 0..C::ROWS_PER_BLOCK - 1 { + let rows = &mut trace[i * trace_width..(i + 2) * trace_width]; + let (local, next) = rows.split_at_mut(trace_width); + let mut local_cols: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut local[get_range(trace_start_col, C::SUBAIR_ROUND_WIDTH)], + ); + let mut next_cols: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut next[get_range(trace_start_col, C::SUBAIR_ROUND_WIDTH)], + ); + if i > 0 { + for j in 0..C::ROUNDS_PER_ROW { + next_cols + .schedule_helper + .intermed_8 + .row_mut(j) + .assign(&local_cols.schedule_helper.intermed_4.row(j)); + if (2..C::ROWS_PER_BLOCK - 3).contains(&i) { + next_cols + .schedule_helper + .intermed_12 + .row_mut(j) + .assign(&local_cols.schedule_helper.intermed_8.row(j)); + } + } + } + if i == C::ROWS_PER_BLOCK - 2 { + // `next` is a digest row. + // Fill in `carry_a` and `carry_e` with dummy values so the constraints on `a` and + // `e` hold. 
+ let const_local_cols = Sha2RoundColsRef::::from_mut::(&local_cols); + Self::generate_carry_ae(const_local_cols.clone(), &mut next_cols); + // Fill in row 16's `intermed_4` with dummy values so the message schedule + // constraints holds on that row + Self::generate_intermed_4(const_local_cols, &mut next_cols); + } + if i < C::MESSAGE_ROWS - 1 { + // i is in 0..3. + // Fill in `local.intermed_12` with dummy values so the message schedule constraints + // hold on rows 1..4. + Self::generate_intermed_12( + &mut local_cols, + Sha2RoundColsRef::::from_mut::(&next_cols), + ); + } + } + } + + /// This function will fill in the cells that we couldn't do during the first pass. + /// This function should be called only after `generate_block_trace` was called for all blocks + /// And [`Self::generate_default_row`] is called for all invalid rows + /// Will populate the missing values of `trace`, where the width of the trace is `trace_width` + /// Note: `trace` needs to be the rows 1..C::ROWS_PER_BLOCK of a block and the first row of the + /// next block + pub fn generate_missing_cells( + &self, + trace: &mut [F], + trace_width: usize, + trace_start_col: usize, + ) { + let rows = &mut trace[(C::ROUND_ROWS - 2) * trace_width..(C::ROUND_ROWS + 1) * trace_width]; + let (last_round_row, rows) = rows.split_at_mut(trace_width); + let (digest_row, next_block_first_row) = rows.split_at_mut(trace_width); + let mut cols_last_round_row: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut last_round_row[trace_start_col..trace_start_col + C::SUBAIR_ROUND_WIDTH], + ); + let mut cols_digest_row: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut digest_row[trace_start_col..trace_start_col + C::SUBAIR_ROUND_WIDTH], + ); + let mut cols_next_block_first_row: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::( + &mut next_block_first_row[trace_start_col..trace_start_col + C::SUBAIR_ROUND_WIDTH], + ); + // Fill in the last round row's `intermed_12` with dummy values so the message 
schedule + // constraints holds on the last round row + Self::generate_intermed_12( + &mut cols_last_round_row, + Sha2RoundColsRef::from_mut::(&cols_digest_row), + ); + // Fill in the digest row's `intermed_12` with dummy values so the message schedule + // constraints holds on the next block's row 0 + Self::generate_intermed_12( + &mut cols_digest_row, + Sha2RoundColsRef::from_mut::(&cols_next_block_first_row), + ); + // Fill in the next block's first row's `intermed_4` with dummy values so the message + // schedule constraints holds on that row + Self::generate_intermed_4( + Sha2RoundColsRef::from_mut::(&cols_digest_row), + &mut cols_next_block_first_row, + ); + } + + /// Fills the `cols` as a padding row + /// Note: we still need to correctly fill in the hash values, carries and intermeds + pub fn generate_default_row(&self, mut cols: Sha2RoundColsRefMut) { + set_arrayview_from_u32_slice( + &mut cols.flags.row_idx, + get_flag_pt_array(&self.row_idx_encoder, C::ROWS_PER_BLOCK), + ); + + // TODO: precompute this + let mut hash = C::get_h() + .iter() + .cloned() + .map(word_into_bits::) + .collect_vec(); + + for i in 0..C::ROUNDS_PER_ROW { + set_arrayview_from_u32_slice( + &mut cols.work_vars.a.row_mut(i), + mem::take(&mut hash[C::ROUNDS_PER_ROW - i - 1]).into_iter(), + ); + set_arrayview_from_u32_slice( + &mut cols.work_vars.e.row_mut(i), + mem::take(&mut hash[C::ROUNDS_PER_ROW - i + 3]).into_iter(), + ); + + set_arrayview_from_u32_slice( + &mut cols.work_vars.carry_a.row_mut(i), + C::get_invalid_carry_a(i).iter().cloned(), + ); + set_arrayview_from_u32_slice( + &mut cols.work_vars.carry_e.row_mut(i), + C::get_invalid_carry_e(i).iter().cloned(), + ); + } + } + + /// The following functions do the calculations in native field since they will be called on + /// padding rows which can overflow and we need to make sure it matches the AIR constraints + /// Puts the correct carries in the `next_row`, the resulting carries can be out of bounds + pub fn generate_carry_ae( 
+ local_cols: Sha2RoundColsRef, + next_cols: &mut Sha2RoundColsRefMut, + ) { + let a = [ + local_cols + .work_vars + .a + .rows() + .into_iter() + .collect::>(), + next_cols.work_vars.a.rows().into_iter().collect::>(), + ] + .concat(); + let e = [ + local_cols + .work_vars + .e + .rows() + .into_iter() + .collect::>(), + next_cols.work_vars.e.rows().into_iter().collect::>(), + ] + .concat(); + for i in 0..C::ROUNDS_PER_ROW { + let cur_a = a[i + 4]; + let sig_a = big_sig0_field::(a[i + 3].as_slice().unwrap()); + let maj_abc = maj_field::( + a[i + 3].as_slice().unwrap(), + a[i + 2].as_slice().unwrap(), + a[i + 1].as_slice().unwrap(), + ); + let d = a[i]; + let cur_e = e[i + 4]; + let sig_e = big_sig1_field::(e[i + 3].as_slice().unwrap()); + let ch_efg = ch_field::( + e[i + 3].as_slice().unwrap(), + e[i + 2].as_slice().unwrap(), + e[i + 1].as_slice().unwrap(), + ); + let h = e[i]; + + let t1 = [h.to_vec(), sig_e, ch_efg.to_vec()]; + let t2 = [sig_a, maj_abc]; + for j in 0..C::WORD_U16S { + let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| { + acc + compose::(&x[j * 16..(j + 1) * 16], 1) + }); + let t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| { + acc + compose::(&x[j * 16..(j + 1) * 16], 1) + }); + let d_limb = compose::(&d.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let cur_a_limb = compose::(&cur_a.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let cur_e_limb = compose::(&cur_e.as_slice().unwrap()[j * 16..(j + 1) * 16], 1); + let sum = d_limb + + t1_limb_sum + + if j == 0 { + F::ZERO + } else { + next_cols.work_vars.carry_e[[i, j - 1]] + } + - cur_e_limb; + let carry_e = sum * (F::from_canonical_u32(1 << 16).inverse()); + + let sum = t1_limb_sum + + t2_limb_sum + + if j == 0 { + F::ZERO + } else { + next_cols.work_vars.carry_a[[i, j - 1]] + } + - cur_a_limb; + let carry_a = sum * (F::from_canonical_u32(1 << 16).inverse()); + next_cols.work_vars.carry_e[[i, j]] = carry_e; + next_cols.work_vars.carry_a[[i, j]] = carry_a; + } + } + } + + /// Puts the correct 
intermed_4 in the `next_row` + pub fn generate_intermed_4( + local_cols: Sha2RoundColsRef, + next_cols: &mut Sha2RoundColsRefMut, + ) { + let w = [ + local_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + next_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + ] + .concat(); + let w_limbs: Vec> = w + .iter() + .map(|x| { + (0..C::WORD_U16S) + .map(|i| compose::(&x.as_slice().unwrap()[i * 16..(i + 1) * 16], 1)) + .collect::>() + }) + .collect(); + for i in 0..C::ROUNDS_PER_ROW { + let sig_w = small_sig0_field::(w[i + 1].as_slice().unwrap()); + let sig_w_limbs: Vec = (0..C::WORD_U16S) + .map(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)) + .collect(); + for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { + next_cols.schedule_helper.intermed_4[[i, j]] = w_limbs[i][j] + *sig_w_limb; + } + } + } + + /// Puts the needed intermed_12 in the `local_row` + pub fn generate_intermed_12( + local_cols: &mut Sha2RoundColsRefMut, + next_cols: Sha2RoundColsRef, + ) { + let w = [ + local_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + next_cols + .message_schedule + .w + .rows() + .into_iter() + .collect::>(), + ] + .concat(); + let w_limbs: Vec> = w + .iter() + .map(|x| { + (0..C::WORD_U16S) + .map(|i| compose::(&x.as_slice().unwrap()[i * 16..(i + 1) * 16], 1)) + .collect::>() + }) + .collect(); + for i in 0..C::ROUNDS_PER_ROW { + // sig_1(w_{t-2}) + let sig_w_2: Vec = (0..C::WORD_U16S) + .map(|j| { + compose::( + &small_sig1_field::(w[i + 2].as_slice().unwrap()) + [j * 16..(j + 1) * 16], + 1, + ) + }) + .collect(); + // w_{t-7} + let w_7 = if i < 3 { + local_cols.schedule_helper.w_3.row(i).to_slice().unwrap() + } else { + w_limbs[i - 3].as_slice() + }; + // w_t + let w_cur = w_limbs[i + 4].as_slice(); + for j in 0..C::WORD_U16S { + let carry = next_cols.message_schedule.carry_or_buffer[[i, j * 2]] + + F::TWO * next_cols.message_schedule.carry_or_buffer[[i, j * 2 + 1]]; + let sum = sig_w_2[j] + w_7[j] - 
carry * F::from_canonical_u32(1 << 16) - w_cur[j] + + if j > 0 { + next_cols.message_schedule.carry_or_buffer[[i, j * 2 - 2]] + + F::from_canonical_u32(2) + * next_cols.message_schedule.carry_or_buffer[[i, j * 2 - 1]] + } else { + F::ZERO + }; + local_cols.schedule_helper.intermed_12[[i, j]] = -sum; + } + } + } +} + +/// `records` consists of pairs of `(input_block, is_last_block)`. +pub fn generate_trace( + filler_helper: &Sha2BlockHasherFillerHelper, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + width: usize, + records: Vec<(Vec, bool)>, +) -> RowMajorMatrix { + for (input, _) in &records { + debug_assert!(input.len() == C::BLOCK_U8S); + } + + let non_padded_height = records.len() * C::ROWS_PER_BLOCK; + let height = next_power_of_two_or_zero(non_padded_height); + let mut values = F::zero_vec(height * width); + + struct BlockContext { + prev_hash: Vec, // len is C::HASH_WORDS + global_block_idx: u32, + input: Vec, // len is C::BLOCK_U8S + is_last_block: bool, + } + let mut block_ctx: Vec> = Vec::with_capacity(records.len()); + let mut prev_hash = C::get_h().to_vec(); + let mut global_block_idx = 1; + for (input, is_last_block) in records { + block_ctx.push(BlockContext { + prev_hash: prev_hash.clone(), + global_block_idx, + input: input.clone(), + is_last_block, + }); + global_block_idx += 1; + if is_last_block { + prev_hash = C::get_h().to_vec(); + } else { + prev_hash = Sha2BlockHasherFillerHelper::::get_block_hash(&prev_hash, input); + } + } + // first pass + values + .par_chunks_exact_mut(width * C::ROWS_PER_BLOCK) + .zip(block_ctx) + .for_each(|(block, ctx)| { + let BlockContext { + prev_hash, + global_block_idx, + input, + is_last_block, + } = ctx; + let input_words = (0..C::BLOCK_WORDS) + .map(|i| { + le_limbs_into_word::( + &(0..C::WORD_U8S) + .map(|j| input[(i + 1) * C::WORD_U8S - j - 1] as u32) + .collect::>(), + ) + }) + .collect::>(); + filler_helper.generate_block_trace( + block, + width, + 0, + &input_words, + 
bitwise_lookup_chip.clone(), + &prev_hash, + is_last_block, + global_block_idx, + ); + }); + // second pass: padding rows + values[width * non_padded_height..] + .par_chunks_mut(width) + .for_each(|row| { + let cols: Sha2RoundColsRefMut = Sha2RoundColsRefMut::from::(row); + filler_helper.generate_default_row(cols); + }); + + // second pass: non-padding rows + values[width..] + .par_chunks_mut(width * C::ROWS_PER_BLOCK) + .take(non_padded_height / C::ROWS_PER_BLOCK) + .for_each(|chunk| { + filler_helper.generate_missing_cells(chunk, width, 0); + }); + RowMajorMatrix::new(values, width) +} diff --git a/crates/circuits/sha2-air/src/utils.rs b/crates/circuits/sha2-air/src/utils.rs new file mode 100644 index 0000000000..ac8b33fb73 --- /dev/null +++ b/crates/circuits/sha2-air/src/utils.rs @@ -0,0 +1,313 @@ +use ndarray::ArrayViewMut; +pub use openvm_circuit_primitives::utils::compose; +use openvm_circuit_primitives::{ + encoder::Encoder, + utils::{not, select}, +}; +use openvm_stark_backend::{ + p3_air::AirBuilder, + p3_field::{FieldAlgebra, PrimeField32}, +}; +use rand::{rngs::StdRng, Rng}; + +use crate::{RotateRight, Sha2BlockHasherSubairConfig}; + +/// Convert a word into a list of 8-bit limbs in little endian +pub fn word_into_u8_limbs(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_U8S) +} + +/// Convert a word into a list of 16-bit limbs in little endian +pub fn word_into_u16_limbs(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_U16S) +} + +/// Convert a word into a list of 1-bit limbs in little endian +pub fn word_into_bits(num: impl Into) -> Vec { + word_into_limbs::(num.into(), C::WORD_BITS) +} + +/// Convert a word into a list of limbs in little endian +pub fn word_into_limbs(num: C::Word, num_limbs: usize) -> Vec { + let limb_bits = std::mem::size_of::() * 8 / num_limbs; + (0..num_limbs) + .map(|i| { + let shifted = num >> (limb_bits * i); + let mask: C::Word = ((1u32 << limb_bits) - 1).into(); + let masked = shifted & mask; + 
masked.try_into().unwrap() + }) + .collect() +} + +/// Convert a u32 into a list of 1-bit limbs in little endian +pub fn u32_into_bits(num: u32) -> Vec { + let limb_bits = 32 / C::WORD_BITS; + (0..C::WORD_BITS) + .map(|i| (num >> (limb_bits * i)) & ((1 << limb_bits) - 1)) + .collect() +} + +/// Convert a list of limbs in little endian into a Word +pub fn le_limbs_into_word(limbs: &[u32]) -> C::Word { + let mut limbs = limbs.to_vec(); + limbs.reverse(); + be_limbs_into_word::(&limbs) +} + +/// Convert a list of limbs in big endian into a Word +pub fn be_limbs_into_word(limbs: &[u32]) -> C::Word { + let limb_bits = C::WORD_BITS / limbs.len(); + limbs.iter().fold(C::Word::from(0), |acc, &limb| { + (acc << limb_bits) | limb.into() + }) +} + +/// Convert a list of limbs in little endian into a u32 +pub fn limbs_into_u32(limbs: &[u32]) -> u32 { + let limb_bits = 32 / limbs.len(); + limbs + .iter() + .rev() + .fold(0, |acc, &limb| (acc << limb_bits) | limb) +} + +/// Rotates `bits` right by `n` bits, assumes `bits` is in little-endian +#[inline] +pub(crate) fn rotr(bits: &[impl Into + Clone], n: usize) -> Vec { + (0..bits.len()) + .map(|i| bits[(i + n) % bits.len()].clone().into()) + .collect() +} + +/// Shifts `bits` right by `n` bits, assumes `bits` is in little-endian +#[inline] +pub(crate) fn shr(bits: &[impl Into + Clone], n: usize) -> Vec { + (0..bits.len()) + .map(|i| { + if i + n < bits.len() { + bits[i + n].clone().into() + } else { + F::ZERO + } + }) + .collect() +} + +/// Computes x ^ y ^ z, where x, y, z are assumed to be boolean +#[inline] +pub(crate) fn xor_bit( + x: impl Into, + y: impl Into, + z: impl Into, +) -> F { + let (x, y, z) = (x.into(), y.into(), z.into()); + (x.clone() * y.clone() * z.clone()) + + (x.clone() * not::(y.clone()) * not::(z.clone())) + + (not::(x.clone()) * y.clone() * not::(z.clone())) + + (not::(x) * not::(y) * z) +} + +/// Computes x ^ y ^ z, where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn xor( + x: &[impl 
Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| xor_bit(x[i].clone(), y[i].clone(), z[i].clone())) + .collect() +} + +/// Choose function from the SHA spec +#[inline] +pub fn ch(x: C::Word, y: C::Word, z: C::Word) -> C::Word { + (x & y) ^ ((!x) & z) +} + +/// Computes Ch(x,y,z), where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn ch_field( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| select(x[i].clone(), y[i].clone(), z[i].clone())) + .collect() +} + +/// Majority function from the SHA spec +pub fn maj(x: C::Word, y: C::Word, z: C::Word) -> C::Word { + (x & y) ^ (x & z) ^ (y & z) +} + +/// Computes Maj(x,y,z), where x, y, z are [C::WORD_BITS] bit numbers +#[inline] +pub(crate) fn maj_field( + x: &[impl Into + Clone], + y: &[impl Into + Clone], + z: &[impl Into + Clone], +) -> Vec { + (0..x.len()) + .map(|i| { + let (x, y, z) = ( + x[i].clone().into(), + y[i].clone().into(), + z[i].clone().into(), + ); + x.clone() * y.clone() + x.clone() * z.clone() + y.clone() * z.clone() + - F::TWO * x * y * z + }) + .collect() +} + +/// Big sigma_0 function from the SHA spec +pub fn big_sig0(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) + } else { + x.rotate_right(28) ^ x.rotate_right(34) ^ x.rotate_right(39) + } +} + +/// Computes BigSigma0(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn big_sig0_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 2), &rotr::(x, 13), &rotr::(x, 22)) + } else { + xor(&rotr::(x, 28), &rotr::(x, 34), &rotr::(x, 39)) + } +} + +/// Big sigma_1 function from the SHA spec +pub fn big_sig1(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) + } else { + x.rotate_right(14) ^ x.rotate_right(18) ^ 
x.rotate_right(41) + } +} + +/// Computes BigSigma1(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn big_sig1_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 6), &rotr::(x, 11), &rotr::(x, 25)) + } else { + xor(&rotr::(x, 14), &rotr::(x, 18), &rotr::(x, 41)) + } +} + +/// Small sigma_0 function from the SHA spec +pub fn small_sig0(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) + } else { + x.rotate_right(1) ^ x.rotate_right(8) ^ (x >> 7) + } +} + +/// Computes SmallSigma0(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn small_sig0_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 7), &rotr::(x, 18), &shr::(x, 3)) + } else { + xor(&rotr::(x, 1), &rotr::(x, 8), &shr::(x, 7)) + } +} + +/// Small sigma_1 function from the SHA spec +pub fn small_sig1(x: C::Word) -> C::Word { + if C::WORD_BITS == 32 { + x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) + } else { + x.rotate_right(19) ^ x.rotate_right(61) ^ (x >> 6) + } +} + +/// Computes SmallSigma1(x), where x is a [C::WORD_BITS] bit number in little-endian +#[inline] +pub(crate) fn small_sig1_field( + x: &[impl Into + Clone], +) -> Vec { + if C::WORD_BITS == 32 { + xor(&rotr::(x, 17), &rotr::(x, 19), &shr::(x, 10)) + } else { + xor(&rotr::(x, 19), &rotr::(x, 61), &shr::(x, 6)) + } +} + +/// Generate a random message of a given length +pub fn get_random_message(rng: &mut StdRng, len: usize) -> Vec { + let mut random_message: Vec = vec![0u8; len]; + rng.fill(&mut random_message[..]); + random_message +} + +/// Wrapper of `get_flag_pt` to get the flag pointer as an array +pub fn get_flag_pt_array(encoder: &Encoder, flag_idx: usize) -> Vec { + encoder.get_flag_pt(flag_idx) +} + +/// Constrain the addition of [C::WORD_BITS] bit words in 16-bit limbs +/// It takes in the terms some in bits some in 16-bit limbs, +/// 
the expected sum in bits and the carries +pub fn constraint_word_addition( + builder: &mut AB, + terms_bits: &[&[impl Into + Clone]], + terms_limb: &[&[impl Into + Clone]], + expected_sum: &[impl Into + Clone], + carries: &[impl Into + Clone], +) { + debug_assert!(terms_bits.iter().all(|x| x.len() == C::WORD_BITS)); + debug_assert!(terms_limb.iter().all(|x| x.len() == C::WORD_U16S)); + assert_eq!(expected_sum.len(), C::WORD_BITS); + assert_eq!(carries.len(), C::WORD_U16S); + + for i in 0..C::WORD_U16S { + let mut limb_sum = if i == 0 { + AB::Expr::ZERO + } else { + carries[i - 1].clone().into() + }; + for term in terms_bits { + limb_sum += compose::(&term[i * 16..(i + 1) * 16], 1); + } + for term in terms_limb { + limb_sum += term[i].clone().into(); + } + let expected_sum_limb = compose::(&expected_sum[i * 16..(i + 1) * 16], 1) + + carries[i].clone().into() * AB::Expr::from_canonical_u32(1 << 16); + builder.assert_eq(limb_sum, expected_sum_limb); + } +} + +pub fn set_arrayview_from_u32_slice( + arrayview: &mut ArrayViewMut, + data: impl IntoIterator, +) { + arrayview + .iter_mut() + .zip(data.into_iter().map(|x| F::from_canonical_u32(x))) + .for_each(|(x, y)| *x = y); +} + +pub fn set_arrayview_from_u8_slice( + arrayview: &mut ArrayViewMut, + data: impl IntoIterator, +) { + arrayview + .iter_mut() + .zip(data.into_iter().map(|x| F::from_canonical_u8(x))) + .for_each(|(x, y)| *x = y); +} diff --git a/crates/circuits/sha256-air/src/air.rs b/crates/circuits/sha256-air/src/air.rs deleted file mode 100644 index b27af6ffa9..0000000000 --- a/crates/circuits/sha256-air/src/air.rs +++ /dev/null @@ -1,612 +0,0 @@ -use std::{array, borrow::Borrow, cmp::max, iter::once}; - -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupBus, - encoder::Encoder, - utils::{not, select}, - SubAir, -}; -use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, - p3_air::{AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra}, - 
p3_matrix::Matrix, -}; - -use super::{ - big_sig0_field, big_sig1_field, ch_field, compose, maj_field, small_sig0_field, - small_sig1_field, Sha256DigestCols, Sha256RoundCols, SHA256_DIGEST_WIDTH, SHA256_H, - SHA256_HASH_WORDS, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROUND_WIDTH, SHA256_WORD_BITS, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; -use crate::{constraint_word_addition, u32_into_u16s}; - -/// Expects the message to be padded to a multiple of 512 bits -#[derive(Clone, Debug)] -pub struct Sha256Air { - pub bitwise_lookup_bus: BitwiseOperationLookupBus, - pub row_idx_encoder: Encoder, - /// Internal bus for self-interactions in this AIR. - bus: PermutationCheckBus, -} - -impl Sha256Air { - pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, self_bus_idx: BusIndex) -> Self { - Self { - bitwise_lookup_bus, - row_idx_encoder: Encoder::new(18, 2, false), - bus: PermutationCheckBus::new(self_bus_idx), - } - } -} - -impl BaseAir for Sha256Air { - fn width(&self) -> usize { - max( - Sha256RoundCols::::width(), - Sha256DigestCols::::width(), - ) - } -} - -impl SubAir for Sha256Air { - /// The start column for the sub-air to use - type AirContext<'a> - = usize - where - Self: 'a, - AB: 'a, - ::Var: 'a, - ::Expr: 'a; - - fn eval<'a>(&'a self, builder: &'a mut AB, start_col: Self::AirContext<'a>) - where - ::Var: 'a, - ::Expr: 'a, - { - self.eval_row(builder, start_col); - self.eval_transitions(builder, start_col); - } -} - -impl Sha256Air { - /// Implements the single row constraints (i.e. 
imposes constraints only on local) - /// Implements some sanity constraints on the row index, flags, and work variables - fn eval_row(&self, builder: &mut AB, start_col: usize) { - let main = builder.main(); - let local = main.row_slice(0); - - // Doesn't matter which column struct we use here as we are only interested in the common - // columns - let local_cols: &Sha256DigestCols = - local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - let flags = &local_cols.flags; - builder.assert_bool(flags.is_round_row); - builder.assert_bool(flags.is_first_4_rows); - builder.assert_bool(flags.is_digest_row); - builder.assert_bool(flags.is_round_row + flags.is_digest_row); - builder.assert_bool(flags.is_last_block); - - self.row_idx_encoder - .eval(builder, &local_cols.flags.row_idx); - builder.assert_one( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=17), - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=3), - flags.is_first_4_rows, - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag_range::(&local_cols.flags.row_idx, 0..=15), - flags.is_round_row, - ); - builder.assert_eq( - self.row_idx_encoder - .contains_flag::(&local_cols.flags.row_idx, &[16]), - flags.is_digest_row, - ); - // If padding row we want the row_idx to be 17 - builder.assert_eq( - self.row_idx_encoder - .contains_flag::(&local_cols.flags.row_idx, &[17]), - flags.is_padding_row(), - ); - - // Constrain a, e, being composed of bits: we make sure a and e are always in the same place - // in the trace matrix Note: this has to be true for every row, even padding rows - for i in 0..SHA256_ROUNDS_PER_ROW { - for j in 0..SHA256_WORD_BITS { - builder.assert_bool(local_cols.hash.a[i][j]); - builder.assert_bool(local_cols.hash.e[i][j]); - } - } - } - - /// Implements constraints for a digest row that ensure proper state transitions between blocks - /// This validates that: - /// The work variables are correctly 
initialized for the next message block - /// For the last message block, the initial state matches SHA256_H constants - fn eval_digest_row( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256DigestCols, - ) { - // Check that if this is the last row of a message or an inpadding row, the hash should be - // the [SHA256_H] - for i in 0..SHA256_ROUNDS_PER_ROW { - let a = next.hash.a[i].map(|x| x.into()); - let e = next.hash.e[i].map(|x| x.into()); - for j in 0..SHA256_WORD_U16S { - let a_limb = compose::(&a[j * 16..(j + 1) * 16], 1); - let e_limb = compose::(&e[j * 16..(j + 1) * 16], 1); - - // If it is a padding row or the last row of a message, the `hash` should be the - // [SHA256_H] - builder - .when( - next.flags.is_padding_row() - + next.flags.is_last_block * next.flags.is_digest_row, - ) - .assert_eq( - a_limb, - AB::Expr::from_canonical_u32( - u32_into_u16s(SHA256_H[SHA256_ROUNDS_PER_ROW - i - 1])[j], - ), - ); - - builder - .when( - next.flags.is_padding_row() - + next.flags.is_last_block * next.flags.is_digest_row, - ) - .assert_eq( - e_limb, - AB::Expr::from_canonical_u32( - u32_into_u16s(SHA256_H[SHA256_ROUNDS_PER_ROW - i + 3])[j], - ), - ); - } - } - - // Check if last row of a non-last block, the `hash` should be equal to the final hash of - // the current block - for i in 0..SHA256_ROUNDS_PER_ROW { - let prev_a = next.hash.a[i].map(|x| x.into()); - let prev_e = next.hash.e[i].map(|x| x.into()); - let cur_a = next.final_hash[SHA256_ROUNDS_PER_ROW - i - 1].map(|x| x.into()); - - let cur_e = next.final_hash[SHA256_ROUNDS_PER_ROW - i + 3].map(|x| x.into()); - for j in 0..SHA256_WORD_U8S { - let prev_a_limb = compose::(&prev_a[j * 8..(j + 1) * 8], 1); - let prev_e_limb = compose::(&prev_e[j * 8..(j + 1) * 8], 1); - - builder - .when(not(next.flags.is_last_block) * next.flags.is_digest_row) - .assert_eq(prev_a_limb, cur_a[j].clone()); - - builder - .when(not(next.flags.is_last_block) * next.flags.is_digest_row) - .assert_eq(prev_e_limb, 
cur_e[j].clone()); - } - } - - // Assert that the previous hash + work vars == final hash. - // That is, `next.prev_hash[i] + local.work_vars[i] == next.final_hash[i]` - // where addition is done modulo 2^32 - for i in 0..SHA256_HASH_WORDS { - let mut carry = AB::Expr::ZERO; - for j in 0..SHA256_WORD_U16S { - let work_var_limb = if i < SHA256_ROUNDS_PER_ROW { - compose::( - &local.work_vars.a[SHA256_ROUNDS_PER_ROW - 1 - i][j * 16..(j + 1) * 16], - 1, - ) - } else { - compose::( - &local.work_vars.e[SHA256_ROUNDS_PER_ROW + 3 - i][j * 16..(j + 1) * 16], - 1, - ) - }; - let final_hash_limb = - compose::(&next.final_hash[i][j * 2..(j + 1) * 2], 8); - - carry = AB::Expr::from(AB::F::from_canonical_u32(1 << 16).inverse()) - * (next.prev_hash[i][j] + work_var_limb + carry - final_hash_limb); - builder - .when(next.flags.is_digest_row) - .assert_bool(carry.clone()); - } - // constrain the final hash limbs two at a time since we can do two checks per - // interaction - for chunk in next.final_hash[i].chunks(2) { - self.bitwise_lookup_bus - .send_range(chunk[0], chunk[1]) - .eval(builder, next.flags.is_digest_row); - } - } - } - - fn eval_transitions(&self, builder: &mut AB, start_col: usize) { - let main = builder.main(); - let local = main.row_slice(0); - let next = main.row_slice(1); - - // Doesn't matter what column structs we use here - let local_cols: &Sha256RoundCols = - local[start_col..start_col + SHA256_ROUND_WIDTH].borrow(); - let next_cols: &Sha256RoundCols = - next[start_col..start_col + SHA256_ROUND_WIDTH].borrow(); - - let local_is_padding_row = local_cols.flags.is_padding_row(); - // Note that there will always be a padding row in the trace since the unpadded height is a - // multiple of 17. So the next row is padding iff the current block is the last - // block in the trace. 
- let next_is_padding_row = next_cols.flags.is_padding_row(); - - // We check that the very last block has `is_last_block` set to true, which guarantees that - // there is at least one complete message. If other digest rows have `is_last_block` set to - // true, then the trace will be interpreted as containing multiple messages. - builder - .when(next_is_padding_row.clone()) - .when(local_cols.flags.is_digest_row) - .assert_one(local_cols.flags.is_last_block); - // If we are in a round row, the next row cannot be a padding row - builder - .when(local_cols.flags.is_round_row) - .assert_zero(next_is_padding_row.clone()); - // The first row must be a round row - builder - .when_first_row() - .assert_one(local_cols.flags.is_round_row); - // If we are in a padding row, the next row must also be a padding row - builder - .when_transition() - .when(local_is_padding_row.clone()) - .assert_one(next_is_padding_row.clone()); - // If we are in a digest row, the next row cannot be a digest row - builder - .when(local_cols.flags.is_digest_row) - .assert_zero(next_cols.flags.is_digest_row); - // Constrain how much the row index changes by - // round->round: 1 - // round->digest: 1 - // digest->round: -16 - // digest->padding: 1 - // padding->padding: 0 - // Other transitions are not allowed by the above constraints - let delta = local_cols.flags.is_round_row * AB::Expr::ONE - + local_cols.flags.is_digest_row - * next_cols.flags.is_round_row - * AB::Expr::from_canonical_u32(16) - * AB::Expr::NEG_ONE - + local_cols.flags.is_digest_row * next_is_padding_row.clone() * AB::Expr::ONE; - - let local_row_idx = self.row_idx_encoder.flag_with_val::( - &local_cols.flags.row_idx, - &(0..18).map(|i| (i, i)).collect::>(), - ); - let next_row_idx = self.row_idx_encoder.flag_with_val::( - &next_cols.flags.row_idx, - &(0..18).map(|i| (i, i)).collect::>(), - ); - - builder - .when_transition() - .assert_eq(local_row_idx.clone() + delta, next_row_idx.clone()); - 
builder.when_first_row().assert_zero(local_row_idx); - - // Constrain the global block index - // We set the global block index to 0 for padding rows - // Starting with 1 so it is not the same as the padding rows - - // Global block index is 1 on first row - builder - .when_first_row() - .assert_one(local_cols.flags.global_block_idx); - - // Global block index is constant on all rows in a block - builder.when(local_cols.flags.is_round_row).assert_eq( - local_cols.flags.global_block_idx, - next_cols.flags.global_block_idx, - ); - // Global block index increases by 1 between blocks - builder - .when_transition() - .when(local_cols.flags.is_digest_row) - .when(next_cols.flags.is_round_row) - .assert_eq( - local_cols.flags.global_block_idx + AB::Expr::ONE, - next_cols.flags.global_block_idx, - ); - // Global block index is 0 on padding rows - builder - .when(local_is_padding_row.clone()) - .assert_zero(local_cols.flags.global_block_idx); - - // Constrain the local block index - // We set the local block index to 0 for padding rows - - // Local block index is constant on all rows in a block - // and its value on padding rows is equal to its value on the first block - builder.when(not(local_cols.flags.is_digest_row)).assert_eq( - local_cols.flags.local_block_idx, - next_cols.flags.local_block_idx, - ); - // Local block index increases by 1 between blocks in the same message - builder - .when(local_cols.flags.is_digest_row) - .when(not(local_cols.flags.is_last_block)) - .assert_eq( - local_cols.flags.local_block_idx + AB::Expr::ONE, - next_cols.flags.local_block_idx, - ); - // Local block index is 0 on padding rows - // Combined with the above, this means that the local block index is 0 in the first block - builder - .when(local_cols.flags.is_digest_row) - .when(local_cols.flags.is_last_block) - .assert_zero(next_cols.flags.local_block_idx); - - self.eval_message_schedule::(builder, local_cols, next_cols); - self.eval_work_vars::(builder, local_cols, next_cols); - let 
next_cols: &Sha256DigestCols = - next[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - self.eval_digest_row(builder, local_cols, next_cols); - let local_cols: &Sha256DigestCols = - local[start_col..start_col + SHA256_DIGEST_WIDTH].borrow(); - self.eval_prev_hash::(builder, local_cols, next_is_padding_row); - } - - /// Constrains that the next block's `prev_hash` is equal to the current block's `hash` - /// Note: the constraining is done by interactions with the chip itself on every digest row - fn eval_prev_hash( - &self, - builder: &mut AB, - local: &Sha256DigestCols, - is_last_block_of_trace: AB::Expr, /* note this indicates the last block of the trace, - * not the last block of the message */ - ) { - // Constrain that next block's `prev_hash` is equal to the current block's `hash` - let composed_hash: [[::Expr; SHA256_WORD_U16S]; SHA256_HASH_WORDS] = - array::from_fn(|i| { - let hash_bits = if i < SHA256_ROUNDS_PER_ROW { - local.hash.a[SHA256_ROUNDS_PER_ROW - 1 - i].map(|x| x.into()) - } else { - local.hash.e[SHA256_ROUNDS_PER_ROW + 3 - i].map(|x| x.into()) - }; - array::from_fn(|j| compose::(&hash_bits[j * 16..(j + 1) * 16], 1)) - }); - // Need to handle the case if this is the very last block of the trace matrix - let next_global_block_idx = select( - is_last_block_of_trace, - AB::Expr::ONE, - local.flags.global_block_idx + AB::Expr::ONE, - ); - // The following interactions constrain certain values from block to block - self.bus.send( - builder, - composed_hash - .into_iter() - .flatten() - .chain(once(next_global_block_idx)), - local.flags.is_digest_row, - ); - - self.bus.receive( - builder, - local - .prev_hash - .into_iter() - .flatten() - .map(|x| x.into()) - .chain(once(local.flags.global_block_idx.into())), - local.flags.is_digest_row, - ); - } - - /// Constrain the message schedule additions for `next` row - /// Note: For every addition we need to constrain the following for each of [SHA256_WORD_U16S] - /// limbs sig_1(w_{t-2})[i] + w_{t-7}[i] + 
sig_0(w_{t-15})[i] + w_{t-16}[i] + - /// carry_w[t][i-1] - carry_w[t][i] * 2^16 - w_t[i] == 0 Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] - fn eval_message_schedule( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256RoundCols, - ) { - // This `w` array contains 8 message schedule words - w_{idx}, ..., w_{idx+7} for some idx - let w = [local.message_schedule.w, next.message_schedule.w].concat(); - - // Constrain `w_3` for `next` row - for i in 0..SHA256_ROUNDS_PER_ROW - 1 { - // here we constrain the w_3 of the i_th word of the next row - // w_3 of next is w[i+4-3] = w[i+1] - let w_3 = w[i + 1].map(|x| x.into()); - let expected_w_3 = next.schedule_helper.w_3[i]; - for j in 0..SHA256_WORD_U16S { - let w_3_limb = compose::(&w_3[j * 16..(j + 1) * 16], 1); - builder - .when(local.flags.is_round_row) - .assert_eq(w_3_limb, expected_w_3[j].into()); - } - } - - // Constrain intermed for `next` row - // We will only constrain intermed_12 for rows [3, 14], and let it be unconstrained for - // other rows Other rows should put the needed value in intermed_12 to make the - // below summation constraint hold - let is_row_3_14 = self - .row_idx_encoder - .contains_flag_range::(&next.flags.row_idx, 3..=14); - // We will only constrain intermed_8 for rows [2, 13], and let it unconstrained for other - // rows - let is_row_2_13 = self - .row_idx_encoder - .contains_flag_range::(&next.flags.row_idx, 2..=13); - for i in 0..SHA256_ROUNDS_PER_ROW { - // w_idx - let w_idx = w[i].map(|x| x.into()); - // sig_0(w_{idx+1}) - let sig_w = small_sig0_field::(&w[i + 1]); - for j in 0..SHA256_WORD_U16S { - let w_idx_limb = compose::(&w_idx[j * 16..(j + 1) * 16], 1); - let sig_w_limb = compose::(&sig_w[j * 16..(j + 1) * 16], 1); - - // We would like to constrain this only on rows 0..16, but we can't do a conditional - // check because the degree is already 3. 
So we must fill in - // `intermed_4` with dummy values on rows 0 and 16 to ensure the constraint holds on - // these rows. - builder.when_transition().assert_eq( - next.schedule_helper.intermed_4[i][j], - w_idx_limb + sig_w_limb, - ); - - builder.when(is_row_2_13.clone()).assert_eq( - next.schedule_helper.intermed_8[i][j], - local.schedule_helper.intermed_4[i][j], - ); - - builder.when(is_row_3_14.clone()).assert_eq( - next.schedule_helper.intermed_12[i][j], - local.schedule_helper.intermed_8[i][j], - ); - } - } - - // Constrain the message schedule additions for `next` row - for i in 0..SHA256_ROUNDS_PER_ROW { - // Note, here by w_{t} we mean the i_th word of the `next` row - // w_{t-7} - let w_7 = if i < 3 { - local.schedule_helper.w_3[i].map(|x| x.into()) - } else { - let w_3 = w[i - 3].map(|x| x.into()); - array::from_fn(|j| compose::(&w_3[j * 16..(j + 1) * 16], 1)) - }; - // sig_0(w_{t-15}) + w_{t-16} - let intermed_16 = local.schedule_helper.intermed_12[i].map(|x| x.into()); - - let carries = array::from_fn(|j| { - next.message_schedule.carry_or_buffer[i][j * 2] - + AB::Expr::TWO * next.message_schedule.carry_or_buffer[i][j * 2 + 1] - }); - - // Constrain `W_{idx} = sig_1(W_{idx-2}) + W_{idx-7} + sig_0(W_{idx-15}) + W_{idx-16}` - // We would like to constrain this only on rows 4..16, but we can't do a conditional - // check because the degree of sum is already 3 So we must fill in - // `intermed_12` with dummy values on rows 0..3 and 15 and 16 to ensure the constraint - // holds on rows 0..4 and 16. Note that the dummy value goes in the previous - // row to make the current row's constraint hold. 
- constraint_word_addition( - // Note: here we can't do a conditional check because the degree of sum is already - // 3 - &mut builder.when_transition(), - &[&small_sig1_field::(&w[i + 2])], - &[&w_7, &intermed_16], - &w[i + 4], - &carries, - ); - - for j in 0..SHA256_WORD_U16S { - // When on rows 4..16 message schedule carries should be 0 or 1 - let is_row_4_15 = next.flags.is_round_row - next.flags.is_first_4_rows; - builder - .when(is_row_4_15.clone()) - .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2]); - builder - .when(is_row_4_15) - .assert_bool(next.message_schedule.carry_or_buffer[i][j * 2 + 1]); - } - // Constrain w being composed of bits - for j in 0..SHA256_WORD_BITS { - builder - .when(next.flags.is_round_row) - .assert_bool(next.message_schedule.w[i][j]); - } - } - } - - /// Constrain the work vars on `next` row according to the sha256 documentation - /// Refer to [https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf] - fn eval_work_vars( - &self, - builder: &mut AB, - local: &Sha256RoundCols, - next: &Sha256RoundCols, - ) { - let a = [local.work_vars.a, next.work_vars.a].concat(); - let e = [local.work_vars.e, next.work_vars.e].concat(); - for i in 0..SHA256_ROUNDS_PER_ROW { - for j in 0..SHA256_WORD_U16S { - // Although we need carry_a <= 6 and carry_e <= 5, constraining carry_a, carry_e in - // [0, 2^8) is enough to prevent overflow and ensure the soundness - // of the addition we want to check - self.bitwise_lookup_bus - .send_range(local.work_vars.carry_a[i][j], local.work_vars.carry_e[i][j]) - .eval(builder, local.flags.is_round_row); - } - - let w_limbs = array::from_fn(|j| { - compose::(&next.message_schedule.w[i][j * 16..(j + 1) * 16], 1) - * next.flags.is_round_row - }); - let k_limbs = array::from_fn(|j| { - self.row_idx_encoder.flag_with_val::( - &next.flags.row_idx, - &(0..16) - .map(|rw_idx| { - ( - rw_idx, - u32_into_u16s(SHA256_K[rw_idx * SHA256_ROUNDS_PER_ROW + i])[j] - as usize, - ) - }) - .collect::>(), - ) - }); - 
- // Constrain `a = h + sig_1(e) + ch(e, f, g) + K + W + sig_0(a) + Maj(a, b, c)` - // We have to enforce this constraint on all rows since the degree of the constraint is - // already 3. So, we must fill in `carry_a` with dummy values on digest rows - // to ensure the constraint holds. - constraint_word_addition( - builder, - &[ - &e[i].map(|x| x.into()), // previous `h` - &big_sig1_field::(&e[i + 3]), // sig_1 of previous `e` - &ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]), /* Ch of previous - * `e`, `f`, `g` */ - &big_sig0_field::(&a[i + 3]), // sig_0 of previous `a` - &maj_field::(&a[i + 3], &a[i + 2], &a[i + 1]), /* Maj of previous - * a, b, c */ - ], - &[&w_limbs, &k_limbs], // K and W - &a[i + 4], // new `a` - &next.work_vars.carry_a[i], // carries of addition - ); - - // Constrain `e = d + h + sig_1(e) + ch(e, f, g) + K + W` - // We have to enforce this constraint on all rows since the degree of the constraint is - // already 3. So, we must fill in `carry_e` with dummy values on digest rows - // to ensure the constraint holds. - constraint_word_addition( - builder, - &[ - &a[i].map(|x| x.into()), // previous `d` - &e[i].map(|x| x.into()), // previous `h` - &big_sig1_field::(&e[i + 3]), /* sig_1 of previous - * `e` */ - &ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]), /* Ch of previous - * `e`, `f`, `g` */ - ], - &[&w_limbs, &k_limbs], // K and W - &e[i + 4], // new `e` - &next.work_vars.carry_e[i], // carries of addition - ); - } - } -} diff --git a/crates/circuits/sha256-air/src/columns.rs b/crates/circuits/sha256-air/src/columns.rs deleted file mode 100644 index 1c735394c3..0000000000 --- a/crates/circuits/sha256-air/src/columns.rs +++ /dev/null @@ -1,140 +0,0 @@ -//! 
WARNING: the order of fields in the structs is important, do not change it - -use openvm_circuit_primitives::{utils::not, AlignedBorrow}; -use openvm_stark_backend::p3_field::FieldAlgebra; - -use super::{ - SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, SHA256_ROW_VAR_CNT, SHA256_WORD_BITS, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; - -/// In each SHA256 block: -/// - First 16 rows use Sha256RoundCols -/// - Final row uses Sha256DigestCols -/// -/// Note that for soundness, we require that there is always a padding row after the last digest row -/// in the trace. Right now, this is true because the unpadded height is a multiple of 17, and thus -/// not a power of 2. -/// -/// Sha256RoundCols and Sha256DigestCols share the same first 3 fields: -/// - flags -/// - work_vars/hash (same type, different name) -/// - schedule_helper -/// -/// This design allows for: -/// 1. Common constraints to work on either struct type by accessing these shared fields -/// 2. Specific constraints to use the appropriate struct, with flags helping to do conditional -/// constraints -/// -/// Note that the `Sha256WorkVarsCols` field it is used for different purposes in the two structs. -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256RoundCols { - pub flags: Sha256FlagsCols, - /// Stores the current state of the working variables - pub work_vars: Sha256WorkVarsCols, - pub schedule_helper: Sha256MessageHelperCols, - pub message_schedule: Sha256MessageScheduleCols, -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256DigestCols { - pub flags: Sha256FlagsCols, - /// Will serve as previous hash values for the next block. - /// - on non-last blocks, this is the final hash of the current block - /// - on last blocks, this is the initial state constants, SHA256_H. - /// The work variables constraints are applied on all rows, so `carry_a` and `carry_e` - /// must be filled in with dummy values to ensure these constraints hold. 
- pub hash: Sha256WorkVarsCols, - pub schedule_helper: Sha256MessageHelperCols, - /// The actual final hash values of the given block - /// Note: the above `hash` will be equal to `final_hash` unless we are on the last block - pub final_hash: [[T; SHA256_WORD_U8S]; SHA256_HASH_WORDS], - /// The final hash of the previous block - /// Note: will be constrained using interactions with the chip itself - pub prev_hash: [[T; SHA256_WORD_U16S]; SHA256_HASH_WORDS], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256MessageScheduleCols { - /// The message schedule words as 32-bit integers - /// The first 16 words will be the message data - pub w: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - /// Will be message schedule carries for rows 4..16 and a buffer for rows 0..4 to be used - /// freely by wrapper chips Note: carries are 2 bit numbers represented using 2 cells as - /// individual bits - pub carry_or_buffer: [[T; SHA256_WORD_U8S]; SHA256_ROUNDS_PER_ROW], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256WorkVarsCols { - /// `a` and `e` after each iteration as 32-bits - pub a: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - pub e: [[T; SHA256_WORD_BITS]; SHA256_ROUNDS_PER_ROW], - /// The carry's used for addition during each iteration when computing `a` and `e` - pub carry_a: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub carry_e: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], -} - -/// These are the columns that are used to help with the message schedule additions -/// Note: these need to be correctly assigned for every row even on padding rows -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256MessageHelperCols { - /// The following are used to move data forward to constrain the message schedule additions - /// The value of `w` (message schedule word) from 3 rounds ago - /// In general, `w_i` means `w` from `i` rounds ago - pub w_3: [[T; SHA256_WORD_U16S]; 
SHA256_ROUNDS_PER_ROW - 1], - /// Here intermediate(i) = w_i + sig_0(w_{i+1}) - /// Intermed_t represents the intermediate t rounds ago - /// This is needed to constrain the message schedule, since we can only constrain on two rows - /// at a time - pub intermed_4: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub intermed_8: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], - pub intermed_12: [[T; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW], -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256FlagsCols { - /// A flag that indicates if the current row is among the first 16 rows of a block. - pub is_round_row: T, - /// A flag that indicates if the current row is among the first 4 rows of a block. - pub is_first_4_rows: T, - /// A flag that indicates if the current row is the last (17th) row of a block. - pub is_digest_row: T, - // A flag that indicates if the current row is the last block of the message. - // This flag is only used in digest rows. - pub is_last_block: T, - /// We will encode the row index [0..17) using 5 cells - pub row_idx: [T; SHA256_ROW_VAR_CNT], - /// The index of the current block in the trace starting at 1. - /// Set to 0 on padding rows. - pub global_block_idx: T, - /// The index of the current block in the current message starting at 0. - /// Resets after every message. - /// Set to 0 on padding rows. - pub local_block_idx: T, -} - -impl> Sha256FlagsCols { - // This refers to the padding rows that are added to the air to make the trace length a power of - // 2. Not to be confused with the padding added to messages as part of the SHA hash - // function. - pub fn is_not_padding_row(&self) -> O { - self.is_round_row + self.is_digest_row - } - - // This refers to the padding rows that are added to the air to make the trace length a power of - // 2. Not to be confused with the padding added to messages as part of the SHA hash - // function. 
- pub fn is_padding_row(&self) -> O - where - O: FieldAlgebra, - { - not(self.is_not_padding_row()) - } -} diff --git a/crates/circuits/sha256-air/src/tests.rs b/crates/circuits/sha256-air/src/tests.rs deleted file mode 100644 index 7ad0229185..0000000000 --- a/crates/circuits/sha256-air/src/tests.rs +++ /dev/null @@ -1,163 +0,0 @@ -use std::{array, borrow::BorrowMut, sync::Arc}; - -use openvm_circuit::arch::{ - instructions::riscv::RV32_CELL_BITS, - testing::{VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, -}; -use openvm_circuit_primitives::{ - bitwise_op_lookup::{ - BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, - SharedBitwiseOperationLookupChip, - }, - SubAir, -}; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - interaction::{BusIndex, InteractionBuilder}, - p3_air::{Air, BaseAir}, - p3_field::{Field, FieldAlgebra, PrimeField32}, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - prover::{cpu::CpuBackend, types::AirProvingContext}, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, - utils::disable_debug_builder, - verifier::VerificationError, - AirRef, Chip, -}; -use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::Rng; - -use crate::{ - Sha256Air, Sha256DigestCols, Sha256FillerHelper, SHA256_BLOCK_U8S, SHA256_DIGEST_WIDTH, - SHA256_HASH_WORDS, SHA256_WIDTH, SHA256_WORD_U8S, -}; - -// A wrapper AIR purely for testing purposes -#[derive(Clone, Debug)] -pub struct Sha256TestAir { - pub sub_air: Sha256Air, -} - -impl BaseAirWithPublicValues for Sha256TestAir {} -impl PartitionedBaseAir for Sha256TestAir {} -impl BaseAir for Sha256TestAir { - fn width(&self) -> usize { - >::width(&self.sub_air) - } -} - -impl Air for Sha256TestAir { - fn eval(&self, builder: &mut AB) { - self.sub_air.eval(builder, 0); - } -} - -const SELF_BUS_IDX: BusIndex = 28; -type F = BabyBear; -type RecordType = Vec<([u8; SHA256_BLOCK_U8S], bool)>; - -// A wrapper Chip purely for testing purposes -pub struct 
Sha256TestChip { - pub step: Sha256FillerHelper, - pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, -} - -impl Chip> for Sha256TestChip -where - Val: PrimeField32, -{ - fn generate_proving_ctx(&self, records: RecordType) -> AirProvingContext> { - let trace = crate::generate_trace::>( - &self.step, - self.bitwise_lookup_chip.as_ref(), - SHA256_WIDTH, - records, - ); - AirProvingContext::simple_no_pis(Arc::new(trace)) - } -} - -#[allow(clippy::type_complexity)] -fn create_air_with_air_ctx() -> ( - (AirRef, AirProvingContext>), - ( - BitwiseOperationLookupAir, - SharedBitwiseOperationLookupChip, - ), -) -where - Val: PrimeField32, -{ - let mut rng = create_seeded_rng(); - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - let len = rng.gen_range(1..100); - let random_records: Vec<_> = (0..len) - .map(|i| { - ( - array::from_fn(|_| rng.gen::()), - rng.gen::() || i == len - 1, - ) - }) - .collect(); - - let air = Sha256TestAir { - sub_air: Sha256Air::new(bitwise_bus, SELF_BUS_IDX), - }; - let chip = Sha256TestChip { - step: Sha256FillerHelper::new(), - bitwise_lookup_chip: bitwise_chip.clone(), - }; - let air_ctx = chip.generate_proving_ctx(random_records); - - ((Arc::new(air), air_ctx), (bitwise_chip.air, bitwise_chip)) -} - -#[test] -fn rand_sha256_test() { - let tester = VmChipTestBuilder::default(); - let (air_ctx, bitwise) = create_air_with_air_ctx(); - let tester = tester - .build() - .load_air_proving_ctx(air_ctx) - .load_periphery(bitwise) - .finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn negative_sha256_test_bad_final_hash() { - let tester = VmChipTestBuilder::default(); - let ((air, mut air_ctx), bitwise) = create_air_with_air_ctx(); - - // Set the final_hash to all zeros - let modify_trace = |trace: &mut RowMajorMatrix| { - trace.row_chunks_exact_mut(1).for_each(|row| { - let mut row_slice = 
row.row_slice(0).to_vec(); - let cols: &mut Sha256DigestCols = row_slice[..SHA256_DIGEST_WIDTH].borrow_mut(); - if cols.flags.is_last_block.is_one() && cols.flags.is_digest_row.is_one() { - for i in 0..SHA256_HASH_WORDS { - for j in 0..SHA256_WORD_U8S { - cols.final_hash[i][j] = F::ZERO; - } - } - row.values.copy_from_slice(&row_slice); - } - }); - }; - - // Modify the air_ctx - let trace = Option::take(&mut air_ctx.common_main).unwrap(); - let mut trace = Arc::into_inner(trace).unwrap(); - modify_trace(&mut trace); - air_ctx.common_main = Some(Arc::new(trace)); - - disable_debug_builder(); - let tester = tester - .build() - .load_air_proving_ctx((air, air_ctx)) - .load_periphery(bitwise) - .finalize(); - tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); -} diff --git a/crates/circuits/sha256-air/src/trace.rs b/crates/circuits/sha256-air/src/trace.rs deleted file mode 100644 index 8cbaebbc55..0000000000 --- a/crates/circuits/sha256-air/src/trace.rs +++ /dev/null @@ -1,558 +0,0 @@ -use std::{array, borrow::BorrowMut, ops::Range}; - -use openvm_circuit_primitives::{ - bitwise_op_lookup::BitwiseOperationLookupChip, encoder::Encoder, - utils::next_power_of_two_or_zero, -}; -use openvm_stark_backend::{ - p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix, p3_maybe_rayon::prelude::*, -}; -use sha2::{compress256, digest::generic_array::GenericArray}; - -use super::{ - big_sig0_field, big_sig1_field, ch_field, columns::Sha256RoundCols, compose, get_flag_pt_array, - maj_field, small_sig0_field, small_sig1_field, SHA256_BLOCK_WORDS, SHA256_DIGEST_WIDTH, - SHA256_HASH_WORDS, SHA256_ROUND_WIDTH, -}; -use crate::{ - big_sig0, big_sig1, ch, columns::Sha256DigestCols, limbs_into_u32, maj, small_sig0, small_sig1, - u32_into_bits_field, u32_into_u16s, SHA256_BLOCK_U8S, SHA256_H, SHA256_INVALID_CARRY_A, - SHA256_INVALID_CARRY_E, SHA256_K, SHA256_ROUNDS_PER_ROW, SHA256_ROWS_PER_BLOCK, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; - -/// A helper 
struct for the SHA256 trace generation. -/// Also, separates the inner AIR from the trace generation. -pub struct Sha256FillerHelper { - pub row_idx_encoder: Encoder, -} - -impl Default for Sha256FillerHelper { - fn default() -> Self { - Self::new() - } -} - -/// The trace generation of SHA256 should be done in two passes. -/// The first pass should do `get_block_trace` for every block and generate the invalid rows through -/// `get_default_row` The second pass should go through all the blocks and call -/// `generate_missing_cells` -impl Sha256FillerHelper { - pub fn new() -> Self { - Self { - row_idx_encoder: Encoder::new(18, 2, false), - } - } - /// This function takes the input_message (padding not handled), the previous hash, - /// and returns the new hash after processing the block input - pub fn get_block_hash( - prev_hash: &[u32; SHA256_HASH_WORDS], - input: [u8; SHA256_BLOCK_U8S], - ) -> [u32; SHA256_HASH_WORDS] { - let mut new_hash = *prev_hash; - let input_array = [GenericArray::from(input)]; - compress256(&mut new_hash, &input_array); - new_hash - } - - /// This function takes a 512-bit chunk of the input message (padding not handled), the previous - /// hash, a flag indicating if it's the last block, the global block index, the local block - /// index, and the buffer values that will be put in rows 0..4. - /// Will populate the given `trace` with the trace of the block, where the width of the trace is - /// `trace_width` and the starting column for the `Sha256Air` is `trace_start_col`. - /// **Note**: this function only generates some of the required trace. Another pass is required, - /// refer to [`Self::generate_missing_cells`] for details. 
- #[allow(clippy::too_many_arguments)] - pub fn generate_block_trace( - &self, - trace: &mut [F], - trace_width: usize, - trace_start_col: usize, - input: &[u32; SHA256_BLOCK_WORDS], - bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, - prev_hash: &[u32; SHA256_HASH_WORDS], - is_last_block: bool, - global_block_idx: u32, - local_block_idx: u32, - ) { - #[cfg(debug_assertions)] - { - assert!(trace.len() == trace_width * SHA256_ROWS_PER_BLOCK); - assert!(trace_start_col + super::SHA256_WIDTH <= trace_width); - if local_block_idx == 0 { - assert!(*prev_hash == SHA256_H); - } - } - let get_range = |start: usize, len: usize| -> Range { start..start + len }; - let mut message_schedule = [0u32; 64]; - message_schedule[..input.len()].copy_from_slice(input); - let mut work_vars = *prev_hash; - for (i, row) in trace.chunks_exact_mut(trace_width).enumerate() { - // doing the 64 rounds in 16 rows - if i < 16 { - let cols: &mut Sha256RoundCols = - row[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - cols.flags.is_round_row = F::ONE; - cols.flags.is_first_4_rows = if i < 4 { F::ONE } else { F::ZERO }; - cols.flags.is_digest_row = F::ZERO; - cols.flags.is_last_block = F::from_bool(is_last_block); - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, i).map(F::from_canonical_u32); - cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); - cols.flags.local_block_idx = F::from_canonical_u32(local_block_idx); - - // W_idx = M_idx - if i < 4 { - for j in 0..SHA256_ROUNDS_PER_ROW { - cols.message_schedule.w[j] = - u32_into_bits_field::(input[i * SHA256_ROUNDS_PER_ROW + j]); - } - } - // W_idx = SIG1(W_{idx-2}) + W_{idx-7} + SIG0(W_{idx-15}) + W_{idx-16} - else { - for j in 0..SHA256_ROUNDS_PER_ROW { - let idx = i * SHA256_ROUNDS_PER_ROW + j; - let nums: [u32; 4] = [ - small_sig1(message_schedule[idx - 2]), - message_schedule[idx - 7], - small_sig0(message_schedule[idx - 15]), - message_schedule[idx - 16], - ]; - let w: u32 = 
nums.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - cols.message_schedule.w[j] = u32_into_bits_field::(w); - - let nums_limbs = nums.map(u32_into_u16s); - let w_limbs = u32_into_u16s(w); - - // fill in the carrys - for k in 0..SHA256_WORD_U16S { - let mut sum = nums_limbs.iter().fold(0, |acc, num| acc + num[k]); - if k > 0 { - sum += (cols.message_schedule.carry_or_buffer[j][k * 2 - 2] - + F::TWO * cols.message_schedule.carry_or_buffer[j][k * 2 - 1]) - .as_canonical_u32(); - } - let carry = (sum - w_limbs[k]) >> 16; - cols.message_schedule.carry_or_buffer[j][k * 2] = - F::from_canonical_u32(carry & 1); - cols.message_schedule.carry_or_buffer[j][k * 2 + 1] = - F::from_canonical_u32(carry >> 1); - } - // update the message schedule - message_schedule[idx] = w; - } - } - // fill in the work variables - for j in 0..SHA256_ROUNDS_PER_ROW { - // t1 = h + SIG1(e) + ch(e, f, g) + K_idx + W_idx - let t1 = [ - work_vars[7], - big_sig1(work_vars[4]), - ch(work_vars[4], work_vars[5], work_vars[6]), - SHA256_K[i * SHA256_ROUNDS_PER_ROW + j], - limbs_into_u32(cols.message_schedule.w[j].map(|f| f.as_canonical_u32())), - ]; - let t1_sum: u32 = t1.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - - // t2 = SIG0(a) + maj(a, b, c) - let t2 = [ - big_sig0(work_vars[0]), - maj(work_vars[0], work_vars[1], work_vars[2]), - ]; - - let t2_sum: u32 = t2.iter().fold(0, |acc, &num| acc.wrapping_add(num)); - - // e = d + t1 - let e = work_vars[3].wrapping_add(t1_sum); - cols.work_vars.e[j] = u32_into_bits_field::(e); - let e_limbs = u32_into_u16s(e); - // a = t1 + t2 - let a = t1_sum.wrapping_add(t2_sum); - cols.work_vars.a[j] = u32_into_bits_field::(a); - let a_limbs = u32_into_u16s(a); - // fill in the carrys - for k in 0..SHA256_WORD_U16S { - let t1_limb = t1.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]); - let t2_limb = t2.iter().fold(0, |acc, &num| acc + u32_into_u16s(num)[k]); - - let mut e_limb = t1_limb + u32_into_u16s(work_vars[3])[k]; - let mut a_limb = t1_limb + 
t2_limb; - if k > 0 { - a_limb += cols.work_vars.carry_a[j][k - 1].as_canonical_u32(); - e_limb += cols.work_vars.carry_e[j][k - 1].as_canonical_u32(); - } - let carry_a = (a_limb - a_limbs[k]) >> 16; - let carry_e = (e_limb - e_limbs[k]) >> 16; - cols.work_vars.carry_a[j][k] = F::from_canonical_u32(carry_a); - cols.work_vars.carry_e[j][k] = F::from_canonical_u32(carry_e); - bitwise_lookup_chip.request_range(carry_a, carry_e); - } - - // update working variables - work_vars[7] = work_vars[6]; - work_vars[6] = work_vars[5]; - work_vars[5] = work_vars[4]; - work_vars[4] = e; - work_vars[3] = work_vars[2]; - work_vars[2] = work_vars[1]; - work_vars[1] = work_vars[0]; - work_vars[0] = a; - } - - // filling w_3 and intermed_4 here and the rest later - if i > 0 { - for j in 0..SHA256_ROUNDS_PER_ROW { - let idx = i * SHA256_ROUNDS_PER_ROW + j; - let w_4 = u32_into_u16s(message_schedule[idx - 4]); - let sig_0_w_3 = u32_into_u16s(small_sig0(message_schedule[idx - 3])); - cols.schedule_helper.intermed_4[j] = - array::from_fn(|k| F::from_canonical_u32(w_4[k] + sig_0_w_3[k])); - if j < SHA256_ROUNDS_PER_ROW - 1 { - let w_3 = message_schedule[idx - 3]; - cols.schedule_helper.w_3[j] = - u32_into_u16s(w_3).map(F::from_canonical_u32); - } - } - } - } - // generate the digest row - else { - let cols: &mut Sha256DigestCols = - row[get_range(trace_start_col, SHA256_DIGEST_WIDTH)].borrow_mut(); - for j in 0..SHA256_ROUNDS_PER_ROW - 1 { - let w_3 = message_schedule[i * SHA256_ROUNDS_PER_ROW + j - 3]; - cols.schedule_helper.w_3[j] = u32_into_u16s(w_3).map(F::from_canonical_u32); - } - cols.flags.is_round_row = F::ZERO; - cols.flags.is_first_4_rows = F::ZERO; - cols.flags.is_digest_row = F::ONE; - cols.flags.is_last_block = F::from_bool(is_last_block); - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, 16).map(F::from_canonical_u32); - cols.flags.global_block_idx = F::from_canonical_u32(global_block_idx); - - cols.flags.local_block_idx = 
F::from_canonical_u32(local_block_idx); - let final_hash: [u32; SHA256_HASH_WORDS] = - array::from_fn(|i| work_vars[i].wrapping_add(prev_hash[i])); - let final_hash_limbs: [[u8; SHA256_WORD_U8S]; SHA256_HASH_WORDS] = - array::from_fn(|i| final_hash[i].to_le_bytes()); - // need to ensure final hash limbs are bytes, in order for - // prev_hash[i] + work_vars[i] == final_hash[i] - // to be constrained correctly - for word in final_hash_limbs.iter() { - for chunk in word.chunks(2) { - bitwise_lookup_chip.request_range(chunk[0] as u32, chunk[1] as u32); - } - } - cols.final_hash = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u8(final_hash_limbs[i][j])) - }); - cols.prev_hash = prev_hash.map(|f| u32_into_u16s(f).map(F::from_canonical_u32)); - let hash = if is_last_block { - SHA256_H.map(u32_into_bits_field::) - } else { - cols.final_hash - .map(|f| u32::from_le_bytes(f.map(|x| x.as_canonical_u32() as u8))) - .map(u32_into_bits_field::) - }; - - for i in 0..SHA256_ROUNDS_PER_ROW { - cols.hash.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; - cols.hash.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3]; - } - } - } - - for i in 0..SHA256_ROWS_PER_BLOCK - 1 { - let rows = &mut trace[i * trace_width..(i + 2) * trace_width]; - let (local, next) = rows.split_at_mut(trace_width); - let local_cols: &mut Sha256RoundCols = - local[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - let next_cols: &mut Sha256RoundCols = - next[get_range(trace_start_col, SHA256_ROUND_WIDTH)].borrow_mut(); - if i > 0 { - for j in 0..SHA256_ROUNDS_PER_ROW { - next_cols.schedule_helper.intermed_8[j] = - local_cols.schedule_helper.intermed_4[j]; - if (2..SHA256_ROWS_PER_BLOCK - 3).contains(&i) { - next_cols.schedule_helper.intermed_12[j] = - local_cols.schedule_helper.intermed_8[j]; - } - } - } - if i == SHA256_ROWS_PER_BLOCK - 2 { - // `next` is a digest row. - // Fill in `carry_a` and `carry_e` with dummy values so the constraints on `a` and - // `e` hold. 
- Self::generate_carry_ae(local_cols, next_cols); - // Fill in row 16's `intermed_4` with dummy values so the message schedule - // constraints holds on that row - Self::generate_intermed_4(local_cols, next_cols); - } - if i <= 2 { - // i is in 0..3. - // Fill in `local.intermed_12` with dummy values so the message schedule constraints - // hold on rows 1..4. - Self::generate_intermed_12(local_cols, next_cols); - } - } - } - - /// This function will fill in the cells that we couldn't do during the first pass. - /// This function should be called only after `generate_block_trace` was called for all blocks - /// And [`Self::generate_default_row`] is called for all invalid rows - /// Will populate the missing values of `trace`, where the width of the trace is `trace_width` - /// and the starting column for the `Sha256Air` is `trace_start_col`. - /// Note: `trace` needs to be the rows 1..17 of a block and the first row of the next block - pub fn generate_missing_cells( - &self, - trace: &mut [F], - trace_width: usize, - trace_start_col: usize, - ) { - // Here row_17 = next blocks row 0 - let rows_15_17 = &mut trace[14 * trace_width..17 * trace_width]; - let (row_15, row_16_17) = rows_15_17.split_at_mut(trace_width); - let (row_16, row_17) = row_16_17.split_at_mut(trace_width); - let cols_15: &mut Sha256RoundCols = - row_15[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - let cols_16: &mut Sha256RoundCols = - row_16[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - let cols_17: &mut Sha256RoundCols = - row_17[trace_start_col..trace_start_col + SHA256_ROUND_WIDTH].borrow_mut(); - // Fill in row 15's `intermed_12` with dummy values so the message schedule constraints - // holds on row 16 - Self::generate_intermed_12(cols_15, cols_16); - // Fill in row 16's `intermed_12` with dummy values so the message schedule constraints - // holds on the next block's row 0 - Self::generate_intermed_12(cols_16, cols_17); - // Fill in row 0's 
`intermed_4` with dummy values so the message schedule constraints holds - // on that row - Self::generate_intermed_4(cols_16, cols_17); - } - - /// Fills the `cols` as a padding row - /// Note: we still need to correctly fill in the hash values, carries and intermeds - pub fn generate_default_row( - self: &Sha256FillerHelper, - cols: &mut Sha256RoundCols, - ) { - cols.flags.row_idx = - get_flag_pt_array(&self.row_idx_encoder, 17).map(F::from_canonical_u32); - - let hash = SHA256_H.map(u32_into_bits_field::); - - for i in 0..SHA256_ROUNDS_PER_ROW { - cols.work_vars.a[i] = hash[SHA256_ROUNDS_PER_ROW - i - 1]; - cols.work_vars.e[i] = hash[SHA256_ROUNDS_PER_ROW - i + 3]; - } - - cols.work_vars.carry_a = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_A[i][j])) - }); - cols.work_vars.carry_e = array::from_fn(|i| { - array::from_fn(|j| F::from_canonical_u32(SHA256_INVALID_CARRY_E[i][j])) - }); - } - - /// The following functions do the calculations in native field since they will be called on - /// padding rows which can overflow and we need to make sure it matches the AIR constraints - /// Puts the correct carrys in the `next_row`, the resulting carrys can be out of bound - fn generate_carry_ae( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, - ) { - let a = [local_cols.work_vars.a, next_cols.work_vars.a].concat(); - let e = [local_cols.work_vars.e, next_cols.work_vars.e].concat(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let cur_a = a[i + 4]; - let sig_a = big_sig0_field::(&a[i + 3]); - let maj_abc = maj_field::(&a[i + 3], &a[i + 2], &a[i + 1]); - let d = a[i]; - let cur_e = e[i + 4]; - let sig_e = big_sig1_field::(&e[i + 3]); - let ch_efg = ch_field::(&e[i + 3], &e[i + 2], &e[i + 1]); - let h = e[i]; - - let t1 = [h, sig_e, ch_efg]; - let t2 = [sig_a, maj_abc]; - for j in 0..SHA256_WORD_U16S { - let t1_limb_sum = t1.iter().fold(F::ZERO, |acc, x| { - acc + compose::(&x[j * 16..(j + 1) * 16], 1) - }); - let 
t2_limb_sum = t2.iter().fold(F::ZERO, |acc, x| { - acc + compose::(&x[j * 16..(j + 1) * 16], 1) - }); - let d_limb = compose::(&d[j * 16..(j + 1) * 16], 1); - let cur_a_limb = compose::(&cur_a[j * 16..(j + 1) * 16], 1); - let cur_e_limb = compose::(&cur_e[j * 16..(j + 1) * 16], 1); - let sum = d_limb - + t1_limb_sum - + if j == 0 { - F::ZERO - } else { - next_cols.work_vars.carry_e[i][j - 1] - } - - cur_e_limb; - let carry_e = sum * (F::from_canonical_u32(1 << 16).inverse()); - - let sum = t1_limb_sum - + t2_limb_sum - + if j == 0 { - F::ZERO - } else { - next_cols.work_vars.carry_a[i][j - 1] - } - - cur_a_limb; - let carry_a = sum * (F::from_canonical_u32(1 << 16).inverse()); - next_cols.work_vars.carry_e[i][j] = carry_e; - next_cols.work_vars.carry_a[i][j] = carry_a; - } - } - } - - /// Puts the correct intermed_4 in the `next_row` - fn generate_intermed_4( - local_cols: &Sha256RoundCols, - next_cols: &mut Sha256RoundCols, - ) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - let sig_w = small_sig0_field::(&w[i + 1]); - let sig_w_limbs: [F; SHA256_WORD_U16S] = - array::from_fn(|j| compose::(&sig_w[j * 16..(j + 1) * 16], 1)); - for (j, sig_w_limb) in sig_w_limbs.iter().enumerate() { - next_cols.schedule_helper.intermed_4[i][j] = w_limbs[i][j] + *sig_w_limb; - } - } - } - - /// Puts the needed intermed_12 in the `local_row` - fn generate_intermed_12( - local_cols: &mut Sha256RoundCols, - next_cols: &Sha256RoundCols, - ) { - let w = [local_cols.message_schedule.w, next_cols.message_schedule.w].concat(); - let w_limbs: Vec<[F; SHA256_WORD_U16S]> = w - .iter() - .map(|x| array::from_fn(|i| compose::(&x[i * 16..(i + 1) * 16], 1))) - .collect(); - for i in 0..SHA256_ROUNDS_PER_ROW { - // sig_1(w_{t-2}) - let sig_w_2: [F; SHA256_WORD_U16S] = array::from_fn(|j| 
{ - compose::(&small_sig1_field::(&w[i + 2])[j * 16..(j + 1) * 16], 1) - }); - // w_{t-7} - let w_7 = if i < 3 { - local_cols.schedule_helper.w_3[i] - } else { - w_limbs[i - 3] - }; - // w_t - let w_cur = w_limbs[i + 4]; - for j in 0..SHA256_WORD_U16S { - let carry = next_cols.message_schedule.carry_or_buffer[i][j * 2] - + F::TWO * next_cols.message_schedule.carry_or_buffer[i][j * 2 + 1]; - let sum = sig_w_2[j] + w_7[j] - carry * F::from_canonical_u32(1 << 16) - w_cur[j] - + if j > 0 { - next_cols.message_schedule.carry_or_buffer[i][j * 2 - 2] - + F::from_canonical_u32(2) - * next_cols.message_schedule.carry_or_buffer[i][j * 2 - 1] - } else { - F::ZERO - }; - local_cols.schedule_helper.intermed_12[i][j] = -sum; - } - } - } -} - -/// Generates a trace for a standalone SHA256 computation (currently only used for testing) -/// `records` consists of pairs of `(input_block, is_last_block)`. -pub fn generate_trace( - step: &Sha256FillerHelper, - bitwise_lookup_chip: &BitwiseOperationLookupChip<8>, - width: usize, - records: Vec<([u8; SHA256_BLOCK_U8S], bool)>, -) -> RowMajorMatrix { - let non_padded_height = records.len() * SHA256_ROWS_PER_BLOCK; - let height = next_power_of_two_or_zero(non_padded_height); - let mut values = F::zero_vec(height * width); - - struct BlockContext { - prev_hash: [u32; 8], - local_block_idx: u32, - global_block_idx: u32, - input: [u8; SHA256_BLOCK_U8S], - is_last_block: bool, - } - let mut block_ctx: Vec = Vec::with_capacity(records.len()); - let mut prev_hash = SHA256_H; - let mut local_block_idx = 0; - let mut global_block_idx = 1; - for (input, is_last_block) in records { - block_ctx.push(BlockContext { - prev_hash, - local_block_idx, - global_block_idx, - input, - is_last_block, - }); - global_block_idx += 1; - if is_last_block { - local_block_idx = 0; - prev_hash = SHA256_H; - } else { - local_block_idx += 1; - prev_hash = Sha256FillerHelper::get_block_hash(&prev_hash, input); - } - } - // first pass - values - 
.par_chunks_exact_mut(width * SHA256_ROWS_PER_BLOCK) - .zip(block_ctx) - .for_each(|(block, ctx)| { - let BlockContext { - prev_hash, - local_block_idx, - global_block_idx, - input, - is_last_block, - } = ctx; - let input_words = array::from_fn(|i| { - limbs_into_u32::(array::from_fn(|j| { - input[(i + 1) * SHA256_WORD_U8S - j - 1] as u32 - })) - }); - step.generate_block_trace( - block, - width, - 0, - &input_words, - bitwise_lookup_chip, - &prev_hash, - is_last_block, - global_block_idx, - local_block_idx, - ); - }); - // second pass: padding rows - values[width * non_padded_height..] - .par_chunks_mut(width) - .for_each(|row| { - let cols: &mut Sha256RoundCols = row.borrow_mut(); - step.generate_default_row(cols); - }); - // second pass: non-padding rows - values[width..] - .par_chunks_mut(width * SHA256_ROWS_PER_BLOCK) - .take(non_padded_height / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - step.generate_missing_cells(chunk, width, 0); - }); - RowMajorMatrix::new(values, width) -} diff --git a/crates/circuits/sha256-air/src/utils.rs b/crates/circuits/sha256-air/src/utils.rs deleted file mode 100644 index ba598f2604..0000000000 --- a/crates/circuits/sha256-air/src/utils.rs +++ /dev/null @@ -1,271 +0,0 @@ -use std::array; - -pub use openvm_circuit_primitives::utils::compose; -use openvm_circuit_primitives::{ - encoder::Encoder, - utils::{not, select}, -}; -use openvm_stark_backend::{p3_air::AirBuilder, p3_field::FieldAlgebra}; - -use super::{Sha256DigestCols, Sha256RoundCols}; - -// ==== Do not change these constants! 
==== -/// Number of bits in a SHA256 word -pub const SHA256_WORD_BITS: usize = 32; -/// Number of 16-bit limbs in a SHA256 word -pub const SHA256_WORD_U16S: usize = SHA256_WORD_BITS / 16; -/// Number of 8-bit limbs in a SHA256 word -pub const SHA256_WORD_U8S: usize = SHA256_WORD_BITS / 8; -/// Number of words in a SHA256 block -pub const SHA256_BLOCK_WORDS: usize = 16; -/// Number of cells in a SHA256 block -pub const SHA256_BLOCK_U8S: usize = SHA256_BLOCK_WORDS * SHA256_WORD_U8S; -/// Number of bits in a SHA256 block -pub const SHA256_BLOCK_BITS: usize = SHA256_BLOCK_WORDS * SHA256_WORD_BITS; -/// Number of rows per block -pub const SHA256_ROWS_PER_BLOCK: usize = 17; -/// Number of rounds per row -pub const SHA256_ROUNDS_PER_ROW: usize = 4; -/// Number of words in a SHA256 hash -pub const SHA256_HASH_WORDS: usize = 8; -/// Number of vars needed to encode the row index with [Encoder] -pub const SHA256_ROW_VAR_CNT: usize = 5; -/// Width of the Sha256RoundCols -pub const SHA256_ROUND_WIDTH: usize = Sha256RoundCols::::width(); -/// Width of the Sha256DigestCols -pub const SHA256_DIGEST_WIDTH: usize = Sha256DigestCols::::width(); -/// Size of the buffer of the first 4 rows of a block (each row's size) -pub const SHA256_BUFFER_SIZE: usize = SHA256_ROUNDS_PER_ROW * SHA256_WORD_U16S * 2; -/// Width of the Sha256Cols -pub const SHA256_WIDTH: usize = if SHA256_ROUND_WIDTH > SHA256_DIGEST_WIDTH { - SHA256_ROUND_WIDTH -} else { - SHA256_DIGEST_WIDTH -}; -/// We can notice that `carry_a`'s and `carry_e`'s are always the same on invalid rows -/// To optimize the trace generation of invalid rows, we have those values precomputed here -pub(crate) const SHA256_INVALID_CARRY_A: [[u32; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW] = [ - [1230919683, 1162494304], - [266373122, 1282901987], - [1519718403, 1008990871], - [923381762, 330807052], -]; -pub(crate) const SHA256_INVALID_CARRY_E: [[u32; SHA256_WORD_U16S]; SHA256_ROUNDS_PER_ROW] = [ - [204933122, 1994683449], - [443873282, 
1544639095], - [719953922, 1888246508], - [194580482, 1075725211], -]; -/// SHA256 constant K's -pub const SHA256_K: [u32; 64] = [ - 0x428a2f98, 0x71374491, 0xb5c0fbcf, 0xe9b5dba5, 0x3956c25b, 0x59f111f1, 0x923f82a4, 0xab1c5ed5, - 0xd807aa98, 0x12835b01, 0x243185be, 0x550c7dc3, 0x72be5d74, 0x80deb1fe, 0x9bdc06a7, 0xc19bf174, - 0xe49b69c1, 0xefbe4786, 0x0fc19dc6, 0x240ca1cc, 0x2de92c6f, 0x4a7484aa, 0x5cb0a9dc, 0x76f988da, - 0x983e5152, 0xa831c66d, 0xb00327c8, 0xbf597fc7, 0xc6e00bf3, 0xd5a79147, 0x06ca6351, 0x14292967, - 0x27b70a85, 0x2e1b2138, 0x4d2c6dfc, 0x53380d13, 0x650a7354, 0x766a0abb, 0x81c2c92e, 0x92722c85, - 0xa2bfe8a1, 0xa81a664b, 0xc24b8b70, 0xc76c51a3, 0xd192e819, 0xd6990624, 0xf40e3585, 0x106aa070, - 0x19a4c116, 0x1e376c08, 0x2748774c, 0x34b0bcb5, 0x391c0cb3, 0x4ed8aa4a, 0x5b9cca4f, 0x682e6ff3, - 0x748f82ee, 0x78a5636f, 0x84c87814, 0x8cc70208, 0x90befffa, 0xa4506ceb, 0xbef9a3f7, 0xc67178f2, -]; - -/// SHA256 initial hash values -pub const SHA256_H: [u32; 8] = [ - 0x6a09e667, 0xbb67ae85, 0x3c6ef372, 0xa54ff53a, 0x510e527f, 0x9b05688c, 0x1f83d9ab, 0x5be0cd19, -]; - -/// Returns the number of blocks required to hash a message of length `len` -pub fn get_sha256_num_blocks(len: u32) -> u32 { - // need to pad with one 1 bit, 64 bits for the message length and then pad until the length - // is divisible by [SHA256_BLOCK_BITS] - ((len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS) as u32 -} - -/// Convert a u32 into a list of bits in little endian then convert each bit into a field element -pub fn u32_into_bits_field(num: u32) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| F::from_bool((num >> i) & 1 == 1)) -} - -/// Convert a u32 into a an array of 2 16-bit limbs in little endian -pub fn u32_into_u16s(num: u32) -> [u32; 2] { - [num & 0xffff, num >> 16] -} - -/// Convert a list of limbs in little endian into a u32 -pub fn limbs_into_u32(limbs: [u32; NUM_LIMBS]) -> u32 { - let limb_bits = 32 / NUM_LIMBS; - limbs - .iter() - .rev() - .fold(0, |acc, &limb| 
(acc << limb_bits) | limb) -} - -/// Rotates `bits` right by `n` bits, assumes `bits` is in little-endian -#[inline] -pub(crate) fn rotr( - bits: &[impl Into + Clone; SHA256_WORD_BITS], - n: usize, -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| bits[(i + n) % SHA256_WORD_BITS].clone().into()) -} - -/// Shifts `bits` right by `n` bits, assumes `bits` is in little-endian -#[inline] -pub(crate) fn shr( - bits: &[impl Into + Clone; SHA256_WORD_BITS], - n: usize, -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| { - if i + n < SHA256_WORD_BITS { - bits[i + n].clone().into() - } else { - F::ZERO - } - }) -} - -/// Computes x ^ y ^ z, where x, y, z are assumed to be boolean -#[inline] -pub(crate) fn xor_bit( - x: impl Into, - y: impl Into, - z: impl Into, -) -> F { - let (x, y, z) = (x.into(), y.into(), z.into()); - (x.clone() * y.clone() * z.clone()) - + (x.clone() * not::(y.clone()) * not::(z.clone())) - + (not::(x.clone()) * y.clone() * not::(z.clone())) - + (not::(x) * not::(y) * z) -} - -/// Computes x ^ y ^ z, where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn xor( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| xor_bit(x[i].clone(), y[i].clone(), z[i].clone())) -} - -/// Choose function from SHA256 -#[inline] -pub fn ch(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ ((!x) & z) -} - -/// Computes Ch(x,y,z), where x, y, z are [SHA256_WORD_BITS] bit numbers -#[inline] -pub(crate) fn ch_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| select(x[i].clone(), y[i].clone(), z[i].clone())) -} - -/// Majority function from SHA256 -pub fn maj(x: u32, y: u32, z: u32) -> u32 { - (x & y) ^ (x & z) ^ (y & z) -} - -/// Computes Maj(x,y,z), where x, y, z are [SHA256_WORD_BITS] 
bit numbers -#[inline] -pub(crate) fn maj_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], - y: &[impl Into + Clone; SHA256_WORD_BITS], - z: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - array::from_fn(|i| { - let (x, y, z) = ( - x[i].clone().into(), - y[i].clone().into(), - z[i].clone().into(), - ); - x.clone() * y.clone() + x.clone() * z.clone() + y.clone() * z.clone() - F::TWO * x * y * z - }) -} - -/// Big sigma_0 function from SHA256 -pub fn big_sig0(x: u32) -> u32 { - x.rotate_right(2) ^ x.rotate_right(13) ^ x.rotate_right(22) -} - -/// Computes BigSigma0(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn big_sig0_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 2), &rotr::(x, 13), &rotr::(x, 22)) -} - -/// Big sigma_1 function from SHA256 -pub fn big_sig1(x: u32) -> u32 { - x.rotate_right(6) ^ x.rotate_right(11) ^ x.rotate_right(25) -} - -/// Computes BigSigma1(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn big_sig1_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 6), &rotr::(x, 11), &rotr::(x, 25)) -} - -/// Small sigma_0 function from SHA256 -pub fn small_sig0(x: u32) -> u32 { - x.rotate_right(7) ^ x.rotate_right(18) ^ (x >> 3) -} - -/// Computes SmallSigma0(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn small_sig0_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 7), &rotr::(x, 18), &shr::(x, 3)) -} - -/// Small sigma_1 function from SHA256 -pub fn small_sig1(x: u32) -> u32 { - x.rotate_right(17) ^ x.rotate_right(19) ^ (x >> 10) -} - -/// Computes SmallSigma1(x), where x is a [SHA256_WORD_BITS] bit number in little-endian -#[inline] -pub(crate) fn small_sig1_field( - x: &[impl Into + Clone; SHA256_WORD_BITS], -) -> [F; SHA256_WORD_BITS] { - xor(&rotr::(x, 
17), &rotr::(x, 19), &shr::(x, 10)) -} - -/// Wrapper of `get_flag_pt` to get the flag pointer as an array -pub fn get_flag_pt_array(encoder: &Encoder, flag_idx: usize) -> [u32; N] { - encoder.get_flag_pt(flag_idx).try_into().unwrap() -} - -/// Constrain the addition of [SHA256_WORD_BITS] bit words in 16-bit limbs -/// It takes in the terms some in bits some in 16-bit limbs, -/// the expected sum in bits and the carries -pub fn constraint_word_addition( - builder: &mut AB, - terms_bits: &[&[impl Into + Clone; SHA256_WORD_BITS]], - terms_limb: &[&[impl Into + Clone; SHA256_WORD_U16S]], - expected_sum: &[impl Into + Clone; SHA256_WORD_BITS], - carries: &[impl Into + Clone; SHA256_WORD_U16S], -) { - for i in 0..SHA256_WORD_U16S { - let mut limb_sum = if i == 0 { - AB::Expr::ZERO - } else { - carries[i - 1].clone().into() - }; - for term in terms_bits { - limb_sum += compose::(&term[i * 16..(i + 1) * 16], 1); - } - for term in terms_limb { - limb_sum += term[i].clone().into(); - } - let expected_sum_limb = compose::(&expected_sum[i * 16..(i + 1) * 16], 1) - + carries[i].clone().into() * AB::Expr::from_canonical_u32(1 << 16); - builder.assert_eq(limb_sum, expected_sum_limb); - } -} diff --git a/crates/sdk/Cargo.toml b/crates/sdk/Cargo.toml index 60b2a89769..bae5f1e3af 100644 --- a/crates/sdk/Cargo.toml +++ b/crates/sdk/Cargo.toml @@ -18,8 +18,8 @@ openvm-ecc-circuit = { workspace = true } openvm-ecc-transpiler = { workspace = true } openvm-keccak256-circuit = { workspace = true } openvm-keccak256-transpiler = { workspace = true } -openvm-sha256-circuit = { workspace = true } -openvm-sha256-transpiler = { workspace = true } +openvm-sha2-circuit = { workspace = true } +openvm-sha2-transpiler = { workspace = true } openvm-pairing-circuit = { workspace = true } openvm-pairing-transpiler = { workspace = true } openvm-native-circuit = { workspace = true } @@ -86,7 +86,7 @@ tco = [ "openvm-circuit/tco", "openvm-rv32im-circuit/tco", "openvm-native-circuit/tco", - 
"openvm-sha256-circuit/tco", + "openvm-sha2-circuit/tco", "openvm-keccak256-circuit/tco", "openvm-bigint-circuit/tco", "openvm-algebra-circuit/tco", @@ -118,7 +118,7 @@ cuda = [ "openvm-bigint-circuit/cuda", "openvm-ecc-circuit/cuda", "openvm-keccak256-circuit/cuda", - "openvm-sha256-circuit/cuda", + "openvm-sha2-circuit/cuda", "openvm-pairing-circuit/cuda", "openvm-native-circuit/cuda", "openvm-rv32im-circuit/cuda", diff --git a/crates/sdk/src/config/global.rs b/crates/sdk/src/config/global.rs index 9699b1ed34..d3c4ec7297 100644 --- a/crates/sdk/src/config/global.rs +++ b/crates/sdk/src/config/global.rs @@ -33,8 +33,8 @@ use openvm_rv32im_circuit::{ use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; -use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha2CpuProverExt}; -use openvm_sha256_transpiler::Sha256TranspilerExtension; +use openvm_sha2_circuit::{Sha2, Sha2CpuProverExt, Sha2Executor}; +use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_backend::{ config::{StarkGenericConfig, Val}, engine::StarkEngine, @@ -55,7 +55,7 @@ cfg_if::cfg_if! 
{ use openvm_keccak256_circuit::Keccak256GpuProverExt; use openvm_native_circuit::NativeGpuProverExt; use openvm_rv32im_circuit::Rv32ImGpuProverExt; - use openvm_sha256_circuit::Sha256GpuProverExt; + use openvm_sha2_circuit::Sha2GpuProverExt; pub use SdkVmGpuBuilder as SdkVmBuilder; } else { pub use SdkVmCpuBuilder as SdkVmBuilder; @@ -81,7 +81,7 @@ pub struct SdkVmConfig { pub rv32i: Option, pub io: Option, pub keccak: Option, - pub sha256: Option, + pub sha2: Option, pub native: Option, pub castf: Option, @@ -118,7 +118,7 @@ impl SdkVmConfig { .rv32m(Default::default()) .io(Default::default()) .keccak(Default::default()) - .sha256(Default::default()) + .sha2(Default::default()) .bigint(Default::default()) .modular(ModularExtension::new(vec![ bn_config.modulus.clone(), @@ -199,8 +199,8 @@ impl TranspilerConfig for SdkVmConfig { if self.keccak.is_some() { transpiler = transpiler.with_extension(Keccak256TranspilerExtension); } - if self.sha256.is_some() { - transpiler = transpiler.with_extension(Sha256TranspilerExtension); + if self.sha2.is_some() { + transpiler = transpiler.with_extension(Sha2TranspilerExtension); } if self.native.is_some() { transpiler = transpiler.with_extension(LongFormTranspilerExtension); @@ -269,7 +269,7 @@ impl SdkVmConfig { let rv32i = config.rv32i.map(|_| Rv32I); let io = config.io.map(|_| Rv32Io); let keccak = config.keccak.map(|_| Keccak256); - let sha256 = config.sha256.map(|_| Sha256); + let sha2 = config.sha2.map(|_| Sha2); let native = config.native.map(|_| Native); let castf = config.castf.map(|_| CastFExtension); let rv32m = config.rv32m; @@ -284,7 +284,7 @@ impl SdkVmConfig { rv32i, io, keccak, - sha256, + sha2, native, castf, rv32m, @@ -315,8 +315,8 @@ pub struct SdkVmConfigInner { pub io: Option, #[extension(executor = "Keccak256Executor")] pub keccak: Option, - #[extension(executor = "Sha256Executor")] - pub sha256: Option, + #[extension(executor = "Sha2Executor")] + pub sha2: Option, #[extension(executor = "NativeExecutor")] 
pub native: Option, #[extension(executor = "CastFExtensionExecutor")] @@ -392,8 +392,8 @@ where if let Some(keccak) = &config.keccak { VmProverExtension::::extend_prover(&Keccak256CpuProverExt, keccak, inventory)?; } - if let Some(sha256) = &config.sha256 { - VmProverExtension::::extend_prover(&Sha2CpuProverExt, sha256, inventory)?; + if let Some(sha2) = &config.sha2 { + VmProverExtension::::extend_prover(&Sha2CpuProverExt, sha2, inventory)?; } if let Some(native) = &config.native { VmProverExtension::::extend_prover(&NativeCpuProverExt, native, inventory)?; @@ -456,8 +456,8 @@ impl VmBuilder for SdkVmGpuBuilder { if let Some(keccak) = &config.keccak { VmProverExtension::::extend_prover(&Keccak256GpuProverExt, keccak, inventory)?; } - if let Some(sha256) = &config.sha256 { - VmProverExtension::::extend_prover(&Sha256GpuProverExt, sha256, inventory)?; + if let Some(sha2) = &config.sha2 { + VmProverExtension::::extend_prover(&Sha2GpuProverExt, sha2, inventory)?; } if let Some(native) = &config.native { VmProverExtension::::extend_prover(&NativeGpuProverExt, native, inventory)?; @@ -566,8 +566,8 @@ impl From for UnitStruct { } } -impl From for UnitStruct { - fn from(_: Sha256) -> Self { +impl From for UnitStruct { + fn from(_: Sha2) -> Self { UnitStruct {} } } @@ -592,7 +592,7 @@ struct SdkVmConfigWithDefaultDeser { pub rv32i: Option, pub io: Option, pub keccak: Option, - pub sha256: Option, + pub sha2: Option, pub native: Option, pub castf: Option, @@ -611,7 +611,7 @@ impl From for SdkVmConfig { rv32i: config.rv32i, io: config.io, keccak: config.keccak, - sha256: config.sha256, + sha2: config.sha2, native: config.native, castf: config.castf, rv32m: config.rv32m, diff --git a/crates/vm/src/arch/execution_control.rs b/crates/vm/src/arch/execution_control.rs new file mode 100644 index 0000000000..5aba19cb3a --- /dev/null +++ b/crates/vm/src/arch/execution_control.rs @@ -0,0 +1,69 @@ +use openvm_instructions::instruction::Instruction; +use 
openvm_stark_backend::p3_field::PrimeField32; + +use super::{ExecutionError, VmChipComplex, VmConfig, VmSegmentState}; + +/// Trait for execution control, determining segmentation and stopping conditions +/// Invariants: +/// - `ExecutionControl` should be stateless. +/// - For E3/E4, `ExecutionControl` is for a specific execution and cannot be used for another +/// execution with different inputs or segmentation criteria. +pub trait ExecutionControl +where + F: PrimeField32, + VC: VmConfig, +{ + /// Host context + type Ctx; + + fn initialize_context(&self) -> Self::Ctx; + + /// Determines if execution should suspend + fn should_suspend( + &self, + state: &mut VmSegmentState, + chip_complex: &VmChipComplex, + ) -> bool; + + /// Called before execution begins + fn on_start( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ); + + /// Called after suspend or terminate + fn on_suspend_or_terminate( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + exit_code: Option, + ); + + fn on_suspend( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ) { + self.on_suspend_or_terminate(state, chip_complex, None); + } + + fn on_terminate( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + exit_code: u32, + ) { + self.on_suspend_or_terminate(state, chip_complex, Some(exit_code)); + } + + /// Execute a single instruction + fn execute_instruction( + &self, + state: &mut VmSegmentState, + instruction: &Instruction, + chip_complex: &mut VmChipComplex, + ) -> Result<(), ExecutionError> + where + F: PrimeField32; +} diff --git a/crates/vm/src/arch/execution_mode/e1.rs b/crates/vm/src/arch/execution_mode/e1.rs new file mode 100644 index 0000000000..49cda03dad --- /dev/null +++ b/crates/vm/src/arch/execution_mode/e1.rs @@ -0,0 +1,32 @@ +use crate::arch::{execution_mode::E1ExecutionCtx, VmSegmentState}; + +pub struct E1Ctx { + instret_end: u64, +} + +impl E1Ctx { + pub fn 
new(instret_end: Option) -> Self { + E1Ctx { + instret_end: if let Some(end) = instret_end { + end + } else { + u64::MAX + }, + } + } +} + +impl Default for E1Ctx { + fn default() -> Self { + Self::new(None) + } +} + +impl E1ExecutionCtx for E1Ctx { + #[inline(always)] + fn on_memory_operation(&mut self, _address_space: u32, _ptr: u32, _size: u32) {} + #[inline(always)] + fn should_suspend(vm_state: &mut VmSegmentState) -> bool { + vm_state.instret >= vm_state.ctx.instret_end + } +} diff --git a/crates/vm/src/arch/execution_mode/tracegen.rs b/crates/vm/src/arch/execution_mode/tracegen.rs new file mode 100644 index 0000000000..6aec38e170 --- /dev/null +++ b/crates/vm/src/arch/execution_mode/tracegen.rs @@ -0,0 +1,97 @@ +use openvm_instructions::instruction::Instruction; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + arch::{ + execution_control::ExecutionControl, ExecutionError, ExecutionState, InstructionExecutor, + VmChipComplex, VmConfig, VmSegmentState, + }, + system::memory::INITIAL_TIMESTAMP, +}; + +#[derive(Default, derive_new::new)] +pub struct TracegenCtx { + pub instret_end: Option, +} + +#[derive(Default)] +pub struct TracegenExecutionControl; + +impl ExecutionControl for TracegenExecutionControl +where + F: PrimeField32, + VC: VmConfig, +{ + type Ctx = TracegenCtx; + + fn initialize_context(&self) -> Self::Ctx { + TracegenCtx { instret_end: None } + } + + fn should_suspend( + &self, + state: &mut VmSegmentState, + _chip_complex: &VmChipComplex, + ) -> bool { + state + .ctx + .instret_end + .is_some_and(|instret_end| state.instret >= instret_end) + } + + fn on_start( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + ) { + chip_complex + .connector_chip_mut() + .begin(ExecutionState::new(state.pc, INITIAL_TIMESTAMP + 1)); + } + + fn on_suspend_or_terminate( + &self, + state: &mut VmSegmentState, + chip_complex: &mut VmChipComplex, + exit_code: Option, + ) { + let timestamp = 
chip_complex.memory_controller().timestamp(); + chip_complex + .connector_chip_mut() + .end(ExecutionState::new(state.pc, timestamp), exit_code); + } + + /// Execute a single instruction + fn execute_instruction( + &self, + state: &mut VmSegmentState, + instruction: &Instruction, + chip_complex: &mut VmChipComplex, + ) -> Result<(), ExecutionError> + where + F: PrimeField32, + { + let timestamp = chip_complex.memory_controller().timestamp(); + + let &Instruction { opcode, .. } = instruction; + + if let Some(executor) = chip_complex.inventory.get_mut_executor(&opcode) { + let memory_controller = &mut chip_complex.base.memory_controller; + let new_state = executor.execute( + memory_controller, + &mut state.streams, + &mut state.rng, + instruction, + ExecutionState::new(state.pc, timestamp), + )?; + state.pc = new_state.pc; + } else { + return Err(ExecutionError::DisabledOperation { + pc: state.pc, + opcode, + }); + }; + + Ok(()) + } +} diff --git a/crates/vm/src/arch/record_arena.rs b/crates/vm/src/arch/record_arena.rs index ec12ea4b5f..f378665e70 100644 --- a/crates/vm/src/arch/record_arena.rs +++ b/crates/vm/src/arch/record_arena.rs @@ -2,7 +2,7 @@ use std::{ borrow::BorrowMut, io::Cursor, marker::PhantomData, - ptr::{copy_nonoverlapping, slice_from_raw_parts_mut}, + ptr::{copy_nonoverlapping, slice_from_raw_parts, slice_from_raw_parts_mut}, }; use openvm_circuit_primitives::utils::next_power_of_two_or_zero; @@ -277,6 +277,25 @@ where record } +/// Converts a field element slice into a record type. +/// This function transmutes the `&mut [F]` to raw bytes, +/// then uses the `CustomBorrow` trait to transmute to the desired record type `T`. +/// ## Safety +/// `slice` must satisfy the requirements of the `CustomBorrow` trait. +/// Note: this function is different from get_record_from_slice only in the type of the slice +/// parameter. Note that get_record_from_slice takes &mut &mut [F] so that you can borrow a record +/// and columns struct from the same slice. 
+pub unsafe fn get_record_from_slice_no_ref<'a, T, F, L>(slice: &'a mut [F], layout: L) -> T +where + [u8]: CustomBorrow<'a, T, L>, +{ + // The alignment of `[u8]` is always satisfiedƒ + let record_buffer = + &mut *slice_from_raw_parts_mut(slice.as_mut_ptr() as *mut u8, size_of_val::<[F]>(slice)); + let record: T = record_buffer.custom_borrow(layout); + record +} + /// A trait that allows for custom implementation of `borrow` given the necessary information /// This is useful for record structs that have dynamic size pub trait CustomBorrow<'a, T, L> { diff --git a/crates/vm/src/arch/segmentation_strategy.rs b/crates/vm/src/arch/segmentation_strategy.rs new file mode 100644 index 0000000000..0336546626 --- /dev/null +++ b/crates/vm/src/arch/segmentation_strategy.rs @@ -0,0 +1,113 @@ +use std::sync::Arc; + +pub const DEFAULT_MAX_SEGMENT_LEN: usize = (1 << 22) - 100; +pub const DEFAULT_MAX_CELLS_IN_SEGMENT: usize = 2_000_000_000; // 2B + +pub trait SegmentationStrategy: + std::fmt::Debug + Send + Sync + std::panic::UnwindSafe + std::panic::RefUnwindSafe +{ + /// Whether the execution should segment based on the trace heights and cells. + /// + /// Air names are provided for debugging purposes. + fn should_segment( + &self, + air_names: &[String], + trace_heights: &[usize], + trace_cells: &[usize], + ) -> bool; + + /// A strategy that segments more aggressively than the current one. + /// + /// Called when `should_segment` results in a segment that is infeasible. Execution will be + /// re-run with the stricter segmentation strategy. + fn stricter_strategy(&self) -> Arc; + + /// Maximum height of any chip in a segment. + fn max_trace_height(&self) -> usize; + + /// Maximum number of cells in a segment. + fn max_cells(&self) -> usize; +} + +/// Default segmentation strategy: segment if any chip's height or cells exceed the limits. 
+#[derive(Debug, Clone)] +pub struct DefaultSegmentationStrategy { + max_segment_len: usize, + max_cells_in_segment: usize, +} + +impl Default for DefaultSegmentationStrategy { + fn default() -> Self { + Self { + max_segment_len: DEFAULT_MAX_SEGMENT_LEN, + max_cells_in_segment: DEFAULT_MAX_CELLS_IN_SEGMENT, + } + } +} + +impl DefaultSegmentationStrategy { + pub fn new_with_max_segment_len(max_segment_len: usize) -> Self { + Self { + max_segment_len, + max_cells_in_segment: DEFAULT_MAX_CELLS_IN_SEGMENT, + } + } + + pub fn new(max_segment_len: usize, max_cells_in_segment: usize) -> Self { + Self { + max_segment_len, + max_cells_in_segment, + } + } + + pub fn max_segment_len(&self) -> usize { + self.max_segment_len + } +} + +const SEGMENTATION_BACKOFF_FACTOR: usize = 4; + +impl SegmentationStrategy for DefaultSegmentationStrategy { + fn max_trace_height(&self) -> usize { + self.max_segment_len + } + + fn max_cells(&self) -> usize { + self.max_cells_in_segment + } + + fn should_segment( + &self, + air_names: &[String], + trace_heights: &[usize], + trace_cells: &[usize], + ) -> bool { + for (i, &height) in trace_heights.iter().enumerate() { + if height > self.max_segment_len { + tracing::info!( + "Should segment because chip {} (name: {}) has height {}", + i, + air_names[i], + height + ); + return true; + } + } + let total_cells: usize = trace_cells.iter().sum(); + if total_cells > self.max_cells_in_segment { + tracing::info!( + "Should segment because total cells across all chips is {}", + total_cells + ); + return true; + } + false + } + + fn stricter_strategy(&self) -> Arc { + Arc::new(Self { + max_segment_len: self.max_segment_len / SEGMENTATION_BACKOFF_FACTOR, + max_cells_in_segment: self.max_cells_in_segment / SEGMENTATION_BACKOFF_FACTOR, + }) + } +} diff --git a/crates/vm/src/arch/testing/cpu.rs b/crates/vm/src/arch/testing/cpu.rs index 70c374968c..cbaa30cb14 100644 --- a/crates/vm/src/arch/testing/cpu.rs +++ b/crates/vm/src/arch/testing/cpu.rs @@ -533,6 +533,25 
@@ where self } + pub fn load_periphery_and_prank_trace( + mut self, + (air, chip): (A, C), + modify_trace: P, + ) -> Self + where + A: AnyRap + 'static, + C: Chip<(), CpuBackend>, + P: Fn(&mut RowMajorMatrix>), + { + let mut ctx = chip.generate_proving_ctx(()); + let trace: Arc>> = Option::take(&mut ctx.common_main).unwrap(); + let mut trace = Arc::into_inner(trace).unwrap(); + modify_trace(&mut trace); + ctx.common_main = Some(Arc::new(trace)); + self.air_ctxs.push((Arc::new(air), ctx)); + self + } + /// Given a function to produce an engine from the max trace height, /// runs a simple test on that engine pub fn test E>( diff --git a/crates/vm/src/system/memory/offline_checker/columns.rs b/crates/vm/src/system/memory/offline_checker/columns.rs index ef9821f859..63c114193d 100644 --- a/crates/vm/src/system/memory/offline_checker/columns.rs +++ b/crates/vm/src/system/memory/offline_checker/columns.rs @@ -62,7 +62,7 @@ impl MemoryWriteAuxCols { #[repr(C)] #[derive(Clone, Copy, Debug, AlignedBorrow)] pub struct MemoryReadAuxCols { - pub(in crate::system::memory) base: MemoryBaseAuxCols, + pub base: MemoryBaseAuxCols, } impl MemoryReadAuxCols { diff --git a/docs/crates/metrics.md b/docs/crates/metrics.md index 6fe3072add..3830f1919c 100644 --- a/docs/crates/metrics.md +++ b/docs/crates/metrics.md @@ -15,6 +15,9 @@ For a segment proof, the following metrics are collected: - `memory_finalize_time_ms` (gauge): The time at the end of preflight execution spent on memory finalization. - `trace_gen_time_ms` (gauge): The time to generate non-cached trace matrices from execution records. - If this is a segment in a VM with continuations enabled, a `segment: segment_idx` label is added to the metric. + - `memory_finalize_time_ms` (gauge): The time in trace generation spent on memory finalization. + - `boundary_finalize_time_ms` (gauge): The time in memory finalization spent on boundary finalization. 
+ - `merkle_finalize_time_ms` (gauge): The time in memory finalization spent on merkle tree finalization. - All metrics collected by [`openvm-stark-backend`](https://github.com/openvm-org/stark-backend/blob/main/docs/metrics.md), in particular `stark_prove_excluding_trace_time_ms` (gauge). - The `total_proof_time_ms` of the proof is instrumented directly when possible. Otherwise, it is calculated as: - The sum `execute_preflight_time_ms + trace_gen_time_ms + stark_prove_excluding_trace_time_ms`. The `execute_metered_time_ms` is excluded for app proofs because it is not run on a per-segment basis. diff --git a/docs/vocs/docs/pages/book/acceleration-using-extensions/overview.mdx b/docs/vocs/docs/pages/book/acceleration-using-extensions/overview.mdx index edfac8d542..17edf81d98 100644 --- a/docs/vocs/docs/pages/book/acceleration-using-extensions/overview.mdx +++ b/docs/vocs/docs/pages/book/acceleration-using-extensions/overview.mdx @@ -3,7 +3,7 @@ OpenVM ships with a set of pre-built extensions maintained by the OpenVM team. Below, we highlight six of these extensions designed to accelerate common arithmetic and cryptographic operations that are notoriously expensive to execute. Some of these extensions have corresponding guest libraries which provide convenient, high-level interfaces for your guest program to interact with the extension. - [`openvm-keccak-guest`](/book/acceleration-using-extensions/keccak) - Keccak256 hash function. See the [Keccak256 guest library](/book/guest-libraries/keccak256) for usage details. -- [`openvm-sha256-guest`](/book/acceleration-using-extensions/sha-256) - SHA-256 hash function. See the [SHA2 guest library](/book/guest-libraries/sha2) for usage details. +- [`openvm-sha2-guest`](/book/acceleration-using-extensions/sha-2) - SHA-2 family of hash functions. See the [SHA-2 guest library](/book/guest-libraries/sha2) for usage details. 
- [`openvm-bigint-guest`](/book/acceleration-using-extensions/big-integer) - Big integer arithmetic for 256-bit signed and unsigned integers. See the [Ruint guest library](/book/guest-libraries/ruint) for using accelerated 256-bit integer ops in rust. - [`openvm-algebra-guest`](/book/acceleration-using-extensions/algebra) - Modular arithmetic and complex field extensions. - [`openvm-ecc-guest`](/book/acceleration-using-extensions/elliptic-curve-cryptography) - Elliptic curve cryptography. See the [K256](/book/guest-libraries/k256) and [P256](/book/guest-libraries/p256) guest libraries for using this extension over the respective curves. @@ -43,9 +43,7 @@ range_tuple_checker_sizes = [256, 8192] [app_vm_config.io] [app_vm_config.keccak] - -[app_vm_config.sha256] - +[app_vm_config.sha2] [app_vm_config.native] [app_vm_config.bigint] diff --git a/docs/vocs/docs/pages/book/acceleration-using-extensions/sha-256.mdx b/docs/vocs/docs/pages/book/acceleration-using-extensions/sha-2.mdx similarity index 52% rename from docs/vocs/docs/pages/book/acceleration-using-extensions/sha-256.mdx rename to docs/vocs/docs/pages/book/acceleration-using-extensions/sha-2.mdx index a4a7f46261..de845fe25f 100644 --- a/docs/vocs/docs/pages/book/acceleration-using-extensions/sha-256.mdx +++ b/docs/vocs/docs/pages/book/acceleration-using-extensions/sha-2.mdx @@ -1,8 +1,8 @@ -# SHA-256 +# SHA-2 -The SHA-256 extension guest provides a function that is meant to be linked to other external libraries. The external libraries can use this function as a hook for the SHA-256 intrinsic. This is enabled only when the target is `zkvm`. +The SHA-2 extension guest provides functions that are meant to be linked to other external libraries. The external libraries can use these functions as a hook for SHA-2 intrinsics. This is enabled only when the target is `zkvm`. We support the SHA-256, SHA-512, and SHA-384 hash functions. 
-- `zkvm_sha256_impl(input: *const u8, len: usize, output: *mut u8)`: This function has `C` ABI. It takes in a pointer to the input, the length of the input, and a pointer to the output buffer. +- `zkvm_shaXXX_impl(input: *const u8, len: usize, output: *mut u8)` where XXX is one of `256`, `512`, or `384`. These functions have `C` ABI. They take in a pointer to the input, the length of the input, and a pointer to the output buffer. In the external library, you can do the following: @@ -31,5 +31,5 @@ fn sha256(input: &[u8]) -> [u8; 32] { For the guest program to build successfully add the following to your `.toml` file: ```toml -[app_vm_config.sha256] +[app_vm_config.sha2] ``` diff --git a/docs/vocs/docs/pages/book/advanced-usage/sdk.mdx b/docs/vocs/docs/pages/book/advanced-usage/sdk.mdx index 8a69d4231b..4a9fbc7842 100644 --- a/docs/vocs/docs/pages/book/advanced-usage/sdk.mdx +++ b/docs/vocs/docs/pages/book/advanced-usage/sdk.mdx @@ -38,7 +38,7 @@ Note that to use `Sdk::riscv32()` or `Sdk::standard()` the `app_vm_config` field ``` -Observe that this standard `openvm.toml` also enables normal Rust and `openvm::io` functions (via the `rv32i`, `rv32m`, and `io` extensions). `keccak` and `sha256` enable intrinsic instructions for the [Keccak](/book/acceleration-using-extensions/keccak) and [SHA-256](/book/acceleration-using-extensions/sha-256) hashes respectively, and `bigint` supports [Big Integer](/book/acceleration-using-extensions/big-integer) operations. +Observe that this standard `openvm.toml` also enables normal Rust and `openvm::io` functions (via the `rv32i`, `rv32m`, and `io` extensions). `keccak` and `sha2` enable intrinsic instructions for the [Keccak](/book/acceleration-using-extensions/keccak) and [SHA-2](/book/acceleration-using-extensions/sha-2) hashes respectively, and `bigint` supports [Big Integer](/book/acceleration-using-extensions/big-integer) operations. 
[Modular](/book/acceleration-using-extensions/algebra) operations for the BN254, Secp256k1 (i.e. K256), Secp256r1 (i.e. P256), and BLS12-381 curves' scalar and coordinate field moduli are also supported, as well as [Complex Field Extension](/book/acceleration-using-extensions/algebra#complex-field-extension) operations over the BN254 and BLS12-381 coordinate fields. [Elliptic Curve Cryptography](/book/acceleration-using-extensions/elliptic-curve-cryptography) operations are also supported for the BN254, Secp256k1, Secp256r1, and BLS12-381 curves, and [Elliptic Curve Pairing](/book/acceleration-using-extensions/elliptic-curve-pairing) checks are supported for the BN254 and BLS12-381 curves. diff --git a/docs/vocs/docs/pages/book/getting-started/introduction.mdx b/docs/vocs/docs/pages/book/getting-started/introduction.mdx index 6dc47c4d71..ebde9f8949 100644 --- a/docs/vocs/docs/pages/book/getting-started/introduction.mdx +++ b/docs/vocs/docs/pages/book/getting-started/introduction.mdx @@ -12,7 +12,7 @@ OpenVM is an open-source zero-knowledge virtual machine (zkVM) framework focused - RISC-V support via RV32IM - A native field arithmetic extension for proof recursion and aggregation - - The Keccak-256 and SHA2-256 hash functions + - The Keccak-256, SHA-256, SHA-512, and SHA-384 hash functions - Int256 arithmetic - Modular arithmetic over arbitrary fields - Elliptic curve operations, including multi-scalar multiplication and ECDSA signature verification, including for the secp256k1 and secp256r1 curves diff --git a/docs/vocs/docs/pages/book/guest-libraries/sha2.mdx b/docs/vocs/docs/pages/book/guest-libraries/sha2.mdx index 5027f4fdfd..8c2c649774 100644 --- a/docs/vocs/docs/pages/book/guest-libraries/sha2.mdx +++ b/docs/vocs/docs/pages/book/guest-libraries/sha2.mdx @@ -3,26 +3,33 @@ The OpenVM SHA-2 guest library provides access to a set of accelerated SHA-2 family hash functions. 
Currently, it supports the following: - SHA-256 +- SHA-512 +- SHA-384 -## SHA-256 - -Refer [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for more details on SHA-256. +Refer [here](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for more details on the SHA-2 family of hash functions. For SHA-256, the SHA2 guest library provides two functions for use in your guest code: - - `sha256(input: &[u8]) -> [u8; 32]`: Computes the SHA-256 hash of the input data and returns it as an array of 32 bytes. - `set_sha256(input: &[u8], output: &mut [u8; 32])`: Sets the output to the SHA-256 hash of the input data into the provided output buffer. -See the full example [here](https://github.com/openvm-org/openvm/blob/main/examples/sha256/src/main.rs). +For SHA-512, we provide: +- `sha512(input: &[u8]) -> [u8; 64]`: Computes the SHA-512 hash of the input data and returns it as an array of 64 bytes. +- `set_sha512(input: &[u8], output: &mut [u8; 64])`: Sets the output to the SHA-512 hash of the input data into the provided output buffer. + +For SHA-384, we provide: +- `sha384(input: &[u8]) -> [u8; 48]`: Computes the SHA-384 hash of the input data and returns it as an array of 48 bytes. +- `set_sha384(input: &[u8], output: &mut [u8; 48])`: Sets the output to the SHA-384 hash of the input data into the provided output buffer. + +See the full example [here](https://github.com/openvm-org/openvm/blob/main/examples/sha2/src/main.rs).
### Example ```rust -// [!include ~/snippets/examples/sha256/src/main.rs:imports] -// [!include ~/snippets/examples/sha256/src/main.rs:main] +// [!include ~/snippets/examples/sha2/src/main.rs:imports] +// [!include ~/snippets/examples/sha2/src/main.rs:main] ``` -To be able to import the `sha256` function, add the following to your `Cargo.toml` file: +To be able to import the `shaXXX` functions and run the example, add the following to your `Cargo.toml` file: ```toml openvm-sha2 = { git = "https://github.com/openvm-org/openvm.git", tag = "v1.4.0" } @@ -34,4 +41,4 @@ hex = { version = "0.4.3" } For the guest program to build successfully add the following to your `.toml` file: ```toml -[app_vm_config.sha256] +[app_vm_config.sha2] diff --git a/docs/vocs/docs/pages/specs/architecture/circuit-architecture.mdx b/docs/vocs/docs/pages/specs/architecture/circuit-architecture.mdx index 6417d8a234..e0f8e9ac89 100644 --- a/docs/vocs/docs/pages/specs/architecture/circuit-architecture.mdx +++ b/docs/vocs/docs/pages/specs/architecture/circuit-architecture.mdx @@ -107,7 +107,7 @@ The chips that fall into these categories are: | FriReducedOpeningChip | – | – | Case 1. | | NativePoseidon2Chip | – | – | Case 1. | | Rv32HintStoreChip | – | – | Case 1. | -| Sha256VmChip | – | – | Case 1. | +| Sha2VmChip | – | – | Case 1. | The PhantomChip satisfies the condition because `1 < 3`. 
diff --git a/docs/vocs/docs/pages/specs/openvm/isa.mdx b/docs/vocs/docs/pages/specs/openvm/isa.mdx index 14b71fa05c..c0198acaa6 100644 --- a/docs/vocs/docs/pages/specs/openvm/isa.mdx +++ b/docs/vocs/docs/pages/specs/openvm/isa.mdx @@ -6,7 +6,7 @@ This specification describes the overall architecture and default VM extensions - [RV32IM](#rv32im-extension): An extension supporting the 32-bit RISC-V ISA with multiplication. - [Native](#native-extension): An extension supporting native field arithmetic for proof recursion and aggregation. - [Keccak-256](#keccak-extension): An extension implementing the Keccak-256 hash function compatibly with RISC-V memory. -[SHA2-256](#sha2-256-extension): An extension implementing the SHA2-256 hash function compatibly with RISC-V memory.
+- [SHA2](#sha-2-extension): An extension implementing the SHA-256, SHA-512, and SHA-384 hash functions compatibly with RISC-V memory. - [BigInt](#bigint-extension): An extension supporting 256-bit signed and unsigned integer arithmetic, including multiplication. This extension respects the RISC-V memory format. - [Algebra](#algebra-extension): An extension supporting modular arithmetic over arbitrary fields and their complex @@ -547,14 +547,16 @@ all memory cells are constrained to be bytes. | -------------- | ----------- | ----------------------------------------------------------------------------------------------------------------- | | KECCAK256_RV32 | `a,b,c,1,2` | `[r32{0}(a):32]_2 = keccak256([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Performs memory accesses with block size `4`. | -### SHA2-256 Extension +### SHA-2 Extension -The SHA2-256 extension supports the SHA2-256 hash function.
The extension operates on address spaces `1` and `2`, +The SHA-2 extension supports the SHA-256, SHA-512, and SHA-384 hash functions. The extension operates on address spaces `1` and `2`, meaning all memory cells are constrained to be bytes. | Name | Operands | Description | | ----------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | | SHA256_RV32 | `a,b,c,1,2` | `[r32{0}(a):32]_2 = sha256([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Does the necessary padding. Performs memory reads with block size `16` and writes with block size `32`. | +| SHA512_RV32 | `a,b,c,1,2` | `[r32{0}(a):64]_2 = sha512([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Does the necessary padding. Performs memory reads with block size `32` and writes with block size `32`. | +| SHA384_RV32 | `a,b,c,1,2` | `[r32{0}(a):64]_2 = sha384([r32{0}(b)..r32{0}(b)+r32{0}(c)]_2)`. Does the necessary padding. Performs memory reads with block size `32` and writes with block size `32`. Writes 64 bytes to memory: the first 48 are the SHA-384 digest and the last 16 are zeros.
| ### BigInt Extension diff --git a/docs/vocs/docs/pages/specs/reference/instruction-reference.mdx b/docs/vocs/docs/pages/specs/reference/instruction-reference.mdx index fa44faebb0..9781522f25 100644 --- a/docs/vocs/docs/pages/specs/reference/instruction-reference.mdx +++ b/docs/vocs/docs/pages/specs/reference/instruction-reference.mdx @@ -130,13 +130,15 @@ In the tables below, we provide the mapping between the `LocalOpcode` and `Phant | ------------- | ---------- | ------------- | | Keccak | `Rv32KeccakOpcode::KECCAK256` | KECCAK256_RV32 | -## SHA2-256 Extension +## SHA-2 Extension #### Instructions | VM Extension | `LocalOpcode` | ISA Instruction | | ------------- | ---------- | ------------- | -| SHA2-256 | `Rv32Sha256Opcode::SHA256` | SHA256_RV32 | +| SHA-2 | `Rv32Sha2Opcode::SHA256` | SHA256_RV32 | +| SHA-2 | `Rv32Sha2Opcode::SHA512` | SHA512_RV32 | +| SHA-2 | `Rv32Sha2Opcode::SHA384` | SHA384_RV32 | ## BigInt Extension diff --git a/docs/vocs/docs/pages/specs/reference/riscv-custom-code.mdx b/docs/vocs/docs/pages/specs/reference/riscv-custom-code.mdx index ebbb17c5f4..7adec11327 100644 --- a/docs/vocs/docs/pages/specs/reference/riscv-custom-code.mdx +++ b/docs/vocs/docs/pages/specs/reference/riscv-custom-code.mdx @@ -5,7 +5,7 @@ The default VM extensions that support transpilation are: - [RV32IM](#rv32im-extension): An extension supporting the 32-bit RISC-V ISA with multiplication. - [Keccak-256](#keccak-extension): An extension implementing the Keccak-256 hash function compatibly with RISC-V memory. -- [SHA2-256](#sha2-256-extension): An extension implementing the SHA2-256 hash function compatibly with RISC-V memory. +- [SHA2](#sha-2-extension): An extension implementing the SHA-256, SHA-512, and SHA-384 hash functions compatibly with RISC-V memory. - [BigInt](#bigint-extension): An extension supporting 256-bit signed and unsigned integer arithmetic, including multiplication. This extension respects the RISC-V memory format. 
- [Algebra](#algebra-extension): An extension supporting modular arithmetic over arbitrary fields and their complex field extensions. This extension respects the RISC-V memory format. - [Elliptic curve](#elliptic-curve-extension): An extension for elliptic curve operations over Weierstrass curves, including addition and doubling. This can be used to implement multi-scalar multiplication and ECDSA scalar multiplication. This extension respects the RISC-V memory format. @@ -85,11 +85,13 @@ implementation is here. But we use `funct3 = 111` because the native extension h | ----------- | --- | ----------- | ------ | ------ | ------------------------------------------- | | keccak256 | R | 0001011 | 100 | 0x0 | `[rd:32]_2 = keccak256([rs1..rs1 + rs2]_2)` | -## SHA2-256 Extension +## SHA-2 Extension | RISC-V Inst | FMT | opcode[6:0] | funct3 | funct7 | RISC-V description and notes | | ----------- | --- | ----------- | ------ | ------ | ---------------------------------------- | | sha256 | R | 0001011 | 100 | 0x1 | `[rd:32]_2 = sha256([rs1..rs1 + rs2]_2)` | +| sha512 | R | 0001011 | 100 | 0x2 | `[rd:64]_2 = sha512([rs1..rs1 + rs2]_2)` | +| sha384 | R | 0001011 | 100 | 0x3 | `[rd:64]_2 = sha384([rs1..rs1 + rs2]_2)`. Last 16 bytes will be set to zeros. | ## BigInt Extension diff --git a/docs/vocs/docs/pages/specs/reference/transpiler.mdx b/docs/vocs/docs/pages/specs/reference/transpiler.mdx index bb16d6fd5c..00ff53e2af 100644 --- a/docs/vocs/docs/pages/specs/reference/transpiler.mdx +++ b/docs/vocs/docs/pages/specs/reference/transpiler.mdx @@ -151,11 +151,13 @@ Each VM extension's behavior is specified below. 
| ----------- | -------------------------------------------------- | | keccak256 | KECCAK256_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | -### SHA2-256 Extension +### SHA-2 Extension | RISC-V Inst | OpenVM Instruction | | ----------- | ----------------------------------------------- | | sha256 | SHA256_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | +| sha512 | SHA512_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | +| sha384 | SHA384_RV32 `ind(rd), ind(rs1), ind(rs2), 1, 2` | ### BigInt Extension diff --git a/examples/sha256/src/main.rs b/examples/sha2/src/main.rs similarity index 100% rename from examples/sha256/src/main.rs rename to examples/sha2/src/main.rs diff --git a/examples/sha256/Cargo.toml b/examples/sha256/Cargo.toml deleted file mode 100644 index 0b5a44bc3e..0000000000 --- a/examples/sha256/Cargo.toml +++ /dev/null @@ -1,22 +0,0 @@ -[package] -name = "sha256-example" -version = "0.0.0" -edition = "2021" - -[workspace] -members = [] - -[dependencies] -openvm = { git = "https://github.com/openvm-org/openvm.git", features = [ - "std", -] } -openvm-sha2 = { git = "https://github.com/openvm-org/openvm.git" } -hex = { version = "0.4.3" } - -[features] -default = [] - -# remove this if copying example outside of monorepo -[patch."https://github.com/openvm-org/openvm.git"] -openvm = { path = "../../crates/toolchain/openvm" } -openvm-sha2 = { path = "../../guest-libs/sha2" } diff --git a/examples/sha256/openvm.toml b/examples/sha256/openvm.toml deleted file mode 100644 index 656bf52414..0000000000 --- a/examples/sha256/openvm.toml +++ /dev/null @@ -1,4 +0,0 @@ -[app_vm_config.rv32i] -[app_vm_config.rv32m] -[app_vm_config.io] -[app_vm_config.sha256] diff --git a/extensions/native/circuit/tests/array.rs b/extensions/native/circuit/tests/array.rs index 9ef5eca8ea..cd26b3ce61 100644 --- a/extensions/native/circuit/tests/array.rs +++ b/extensions/native/circuit/tests/array.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use 
openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::{AsmBuilder, AsmConfig}, ir::{Array, Config, Ext, Felt, RVar, Usize, Var}, @@ -104,7 +104,7 @@ fn test_array_eq() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[should_panic] @@ -125,7 +125,7 @@ fn test_array_eq_neg() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -161,7 +161,7 @@ fn test_slice_variable_impl_happy_path() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -183,5 +183,5 @@ fn test_slice_assert_eq_neg() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/circuit/tests/conditionals.rs b/extensions/native/circuit/tests/conditionals.rs index 29fa85386a..b6ab8cf8ad 100644 --- a/extensions/native/circuit/tests/conditionals.rs +++ b/extensions/native/circuit/tests/conditionals.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{asm::AsmBuilder, ir::Var}; use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, FieldAlgebra}; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -50,7 +50,7 @@ fn test_compiler_conditionals() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -79,7 +79,7 @@ fn test_compiler_conditionals_v2() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] diff --git a/extensions/native/circuit/tests/ext.rs b/extensions/native/circuit/tests/ext.rs index 5da70cb53b..70494bb6bd 100644 --- 
a/extensions/native/circuit/tests/ext.rs +++ b/extensions/native/circuit/tests/ext.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::AsmBuilder, ir::{Ext, Felt}, @@ -31,7 +31,7 @@ fn test_ext2felt() { let program = builder.compile_isa(); println!("{}", program); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -60,5 +60,5 @@ fn test_ext_from_base_slice() { let program = builder.compile_isa(); println!("{}", program); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/circuit/tests/for_loops.rs b/extensions/native/circuit/tests/for_loops.rs index 123a416cdb..709105ee32 100644 --- a/extensions/native/circuit/tests/for_loops.rs +++ b/extensions/native/circuit/tests/for_loops.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::{AsmBuilder, AsmConfig}, ir::{Array, Var}, @@ -46,7 +46,7 @@ fn test_compiler_for_loops() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -83,7 +83,7 @@ fn test_compiler_zip_fixed() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -125,7 +125,7 @@ fn test_compiler_zip_dyn() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -162,7 +162,7 @@ fn test_compiler_nested_array_loop() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -182,5 +182,5 @@ fn test_compiler_bneinc() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + 
test_execute_program(program, vec![]); } diff --git a/extensions/native/circuit/tests/fri_ro_eval.rs b/extensions/native/circuit/tests/fri_ro_eval.rs index 6f332d22b6..dcba950a08 100644 --- a/extensions/native/circuit/tests/fri_ro_eval.rs +++ b/extensions/native/circuit/tests/fri_ro_eval.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler}, conversion::{convert_program, CompilerOptions}, @@ -89,5 +89,5 @@ fn test_single_reduced_opening_eval() { let asm_code = compiler.code(); let program = convert_program::(asm_code, CompilerOptions::default()); - execute_program(program, vec![mat_opening]); + test_execute_program(program, vec![mat_opening]); } diff --git a/extensions/native/circuit/tests/hint.rs b/extensions/native/circuit/tests/hint.rs index 05aff8a390..5ac1722577 100644 --- a/extensions/native/circuit/tests/hint.rs +++ b/extensions/native/circuit/tests/hint.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{asm::AsmBuilder, ir::Felt}; use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, Field, FieldAlgebra}; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -29,5 +29,5 @@ fn test_hint_bits_felt() { let program = builder.compile_isa(); println!("{}", program); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/circuit/tests/io.rs b/extensions/native/circuit/tests/io.rs index 58b2c46b8f..ab16bc276a 100644 --- a/extensions/native/circuit/tests/io.rs +++ b/extensions/native/circuit/tests/io.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::{AsmBuilder, AsmCompiler}, conversion::{convert_program, CompilerOptions}, @@ -61,5 +61,5 @@ fn test_io() { println!("{}", asm_code); let program 
= convert_program::(asm_code, CompilerOptions::default()); - execute_program(program, witness_stream); + test_execute_program(program, witness_stream); } diff --git a/extensions/native/circuit/tests/poseidon2.rs b/extensions/native/circuit/tests/poseidon2.rs index de7dbc6e9b..b3dee9f4c1 100644 --- a/extensions/native/circuit/tests/poseidon2.rs +++ b/extensions/native/circuit/tests/poseidon2.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::AsmBuilder, ir::{Array, Var, PERMUTATION_WIDTH}, @@ -49,7 +49,7 @@ fn test_compiler_poseidon2_permute() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -80,5 +80,5 @@ fn test_compiler_poseidon2_hash_1() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/circuit/tests/range_check.rs b/extensions/native/circuit/tests/range_check.rs index 959f2bae9f..f1e7d9948f 100644 --- a/extensions/native/circuit/tests/range_check.rs +++ b/extensions/native/circuit/tests/range_check.rs @@ -1,4 +1,4 @@ -use openvm_native_circuit::execute_program; +use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{asm::AsmBuilder, prelude::*}; use openvm_stark_backend::p3_field::{extension::BinomialExtensionField, FieldAlgebra}; use openvm_stark_sdk::p3_baby_bear::BabyBear; @@ -24,7 +24,7 @@ fn test_range_check_v() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] @@ -38,5 +38,5 @@ fn test_range_check_v_neg() { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } diff --git a/extensions/native/recursion/src/challenger/duplex.rs 
b/extensions/native/recursion/src/challenger/duplex.rs index 7c0cd4dd88..309a95f766 100644 --- a/extensions/native/recursion/src/challenger/duplex.rs +++ b/extensions/native/recursion/src/challenger/duplex.rs @@ -189,7 +189,7 @@ impl ChallengerVariable for DuplexChallengerVariable { #[cfg(test)] mod tests { - use openvm_native_circuit::execute_program; + use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::{AsmBuilder, AsmConfig}, ir::Felt, @@ -241,7 +241,7 @@ mod tests { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] diff --git a/extensions/native/recursion/src/fri/domain.rs b/extensions/native/recursion/src/fri/domain.rs index cdc8fc242c..8b6aceefb7 100644 --- a/extensions/native/recursion/src/fri/domain.rs +++ b/extensions/native/recursion/src/fri/domain.rs @@ -156,7 +156,7 @@ where #[cfg(test)] pub(crate) mod tests { - use openvm_native_circuit::execute_program; + use openvm_native_circuit::test_execute_program; use openvm_native_compiler::asm::AsmBuilder; use openvm_stark_backend::{ config::{Domain, StarkGenericConfig, Val}, @@ -276,7 +276,7 @@ pub(crate) mod tests { builder.halt(); let program = builder.compile_isa(); - execute_program(program, vec![]); + test_execute_program(program, vec![]); } #[test] fn test_domain_static() { diff --git a/extensions/native/recursion/src/fri/two_adic_pcs.rs b/extensions/native/recursion/src/fri/two_adic_pcs.rs index 3e66e05e61..470ec67889 100644 --- a/extensions/native/recursion/src/fri/two_adic_pcs.rs +++ b/extensions/native/recursion/src/fri/two_adic_pcs.rs @@ -746,6 +746,6 @@ pub mod tests { #[test] fn test_two_adic_fri_pcs_single_batch() { let (program, witness) = build_test_fri_with_cols_and_log2_rows(10, 10); - openvm_native_circuit::execute_program(program, witness); + openvm_native_circuit::test_execute_program(program, witness); } } diff --git a/extensions/native/recursion/src/hints.rs 
b/extensions/native/recursion/src/hints.rs index b65f6ba647..642d3d5379 100644 --- a/extensions/native/recursion/src/hints.rs +++ b/extensions/native/recursion/src/hints.rs @@ -446,7 +446,7 @@ impl Hintable for Commitments> { #[cfg(test)] mod test { - use openvm_native_circuit::execute_program; + use openvm_native_circuit::test_execute_program; use openvm_native_compiler::{ asm::AsmBuilder, ir::{Ext, Felt, Var}, @@ -480,7 +480,7 @@ mod test { builder.halt(); let program = builder.compile_isa(); - execute_program(program, stream); + test_execute_program(program, stream); } #[test] @@ -527,6 +527,6 @@ mod test { builder.halt(); let program = builder.compile_isa(); - execute_program(program, stream); + test_execute_program(program, stream); } } diff --git a/extensions/sha256/circuit/Cargo.toml b/extensions/sha2/circuit/Cargo.toml similarity index 79% rename from extensions/sha256/circuit/Cargo.toml rename to extensions/sha2/circuit/Cargo.toml index 740a4302f5..8a34b93f88 100644 --- a/extensions/sha256/circuit/Cargo.toml +++ b/extensions/sha2/circuit/Cargo.toml @@ -1,9 +1,9 @@ [package] -name = "openvm-sha256-circuit" +name = "openvm-sha2-circuit" version.workspace = true authors.workspace = true edition.workspace = true -description = "OpenVM circuit extension for sha256" +description = "OpenVM circuit extension for SHA-2" [dependencies] openvm-stark-backend = { workspace = true } @@ -11,20 +11,25 @@ openvm-stark-sdk = { workspace = true } openvm-cuda-backend = { workspace = true, optional = true } openvm-cuda-common = { workspace = true, optional = true } openvm-circuit-primitives = { workspace = true } +openvm-circuit-primitives-derive = { workspace = true } openvm-circuit-derive = { workspace = true } openvm-circuit = { workspace = true } openvm-instructions = { workspace = true } -openvm-sha256-transpiler = { workspace = true } +openvm-sha2-transpiler = { workspace = true } openvm-rv32im-circuit = { workspace = true } -openvm-sha256-air = { workspace = true } + 
+openvm-sha2-air = { workspace = true } derive-new.workspace = true derive_more = { workspace = true, features = ["from"] } rand.workspace = true serde.workspace = true -sha2 = { version = "0.10", default-features = false } +sha2 = { version = "0.10", features = ["compress"] } +ndarray = { workspace = true, default-features = false } strum = { workspace = true } cfg-if.workspace = true +num_enum = { workspace = true } +itertools = { workspace = true } [dev-dependencies] hex = { workspace = true } @@ -61,3 +66,6 @@ touchemall = [ "openvm-cuda-common/touchemall", "openvm-rv32im-circuit/touchemall", ] + +[package.metadata.cargo-shear] +ignored = ["ndarray"] \ No newline at end of file diff --git a/extensions/sha256/circuit/README.md b/extensions/sha2/circuit/README.md similarity index 56% rename from extensions/sha256/circuit/README.md rename to extensions/sha2/circuit/README.md index 1e794cd35c..de2100b261 100644 --- a/extensions/sha256/circuit/README.md +++ b/extensions/sha2/circuit/README.md @@ -1,28 +1,43 @@ -# SHA256 VM Extension +# SHA-2 VM Extension -This crate contains the circuit for the SHA256 VM extension. +This crate contains circuits for the SHA-2 family of hash functions. +We support SHA-256, SHA-512, and SHA-384. -## SHA-256 Algorithm Summary +## SHA-2 Algorithms Summary -See the [FIPS standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf), in particular, section 6.2 for reference. +The SHA-256, SHA-512, and SHA-384 algorithms are similar in structure. +We will first describe the SHA-256 algorithm, and then describe the differences between the three algorithms. + +See the [FIPS standard](https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf) for reference. In particular, sections 6.2, 6.4, and 6.5. In short the SHA-256 algorithm works as follows. 1. Pad the message to 512 bits and split it into 512-bit 'blocks'. -2. Initialize a hash state consisting of eight 32-bit words. +2. 
Initialize a hash state consisting of eight 32-bit words to a specific constant value. 3. For each block, - 1. split the message into 16 32-bit words and produce 48 more 'message schedule' words based on them. - 2. apply 64 'rounds' to update the hash state based on the message schedule. - 3. add the previous block's final hash state to the current hash state (modulo `2^32`). + 1. split the message into 16 32-bit words and produce 48 more words based on them. The 16 message words together with the 48 additional words are called the 'message schedule'. + 2. apply a scrambling function 64 times to the hash state to update it based on the message schedule. We call each update a 'round'. + 3. add the previous block's final hash state to the current hash state (modulo $2^{32}$). 4. The output is the final hash state +The differences with the SHA-512 algorithm are that: +- SHA-512 uses 64-bit words, 1024-bit blocks, performs 80 rounds, and produces a 512-bit output. +- all the arithmetic is done modulo $2^{64}$. +- the initial hash state is different. + +The SHA-384 algorithm is a truncation of the SHA-512 output to 384 bits, and the only difference is that the initial hash state is different. + ## Design Overview -This chip produces an AIR that consists of 17 rows for each block (512 bits) in the message, and no more rows. -The first 16 rows of each block are called 'round rows', and each of them represents four rounds of the SHA-256 algorithm. -Each row constrains updates to the working variables on each round, and it also constrains the message schedule words based on previous rounds. -The final row is called a 'digest row' and it produces a final hash for the block, computed as the sum of the working variables and the previous block's final hash. +We reuse the same AIR code to produce circuits for all three algorithms. +To achieve this, we parameterize the AIR by constants (such as the word size, number of rounds, and block size) that are specific to each algorithm. 
+ +This chip produces an AIR that consists of $R+1$ rows for each block of the message, and no more rows +(for SHA-256, $R = 16$ and for SHA-512 and SHA-384, $R = 20$). +The first $R$ rows of each block are called 'round rows', and each of them constrains four rounds of the hash algorithm. +Each row constrains updates to the working variables on each round, and also constrains the message schedule words based on previous rounds. +The final row of each block is called a 'digest row' and it produces a final hash for the block, computed as the sum of the working variables and the previous block's final hash. -Note that this chip only supports messages of length less than `2^29` bytes. +Note that this chip only supports messages of length less than $2^{29}$ bytes. ### Storing working variables @@ -50,7 +65,7 @@ Since we can reliably constrain values from four rounds ago, we can build up `in The last block of every message should have the `is_last_block` flag set to `1`. Note that `is_last_block` is not constrained to be true for the last block of every message, instead it *defines* what the last block of a message is. -For instance, if we produce an air with 10 blocks and only the last block has `is_last_block = 1` then the constraints will interpret it as a single message of length 10 blocks. +For instance, if we produce a trace with 10 blocks and only the last block has `is_last_block = 1` then the constraints will interpret it as a single message of length 10 blocks. If, however, we set `is_last_block` to true for the 6th block, the trace will be interpreted as hashing two messages, each of length 5 blocks. Note that we do constrain, however, that the very last block of the trace has `is_last_block = 1`. @@ -63,11 +78,11 @@ We use this trick in several places in this chip. ### Block index counter variables -There are two "block index" counter variables in each row of the air named `global_block_idx` and `local_block_idx`. 
-Both of these variables take on the same value on all 17 rows in a block. +There are two "block index" counter variables in each row named `global_block_idx` and `local_block_idx`. +Both of these variables take on the same value on all $R+1$ rows in a block. The `global_block_idx` is the index of the block in the entire trace. -The very first 17 rows in the trace will have `global_block_idx = 1` and the counter will increment by 1 between blocks. +The very first block in the trace will have `global_block_idx = 1` on each row and the counter will increment by 1 between blocks. The padding rows will all have `global_block_idx = 0`. The `global_block_idx` is used in interaction constraints to constrain the value of `hash` between blocks. @@ -79,15 +94,16 @@ The `local_block_idx` is used to calculate the length of the message processed s ### VM air vs SubAir -The SHA-256 VM extension chip uses the `Sha256Air` SubAir to help constrain the SHA-256 hash. -The VM extension air constrains the correctness of the SHA message padding, while the SubAir adds all other constraints related to the hash algorithm. -The VM extension air also constrains memory reads and writes. +The SHA-2 VM extension chip uses the `Sha2Air` SubAir to help constrain the appropriate SHA-2 hash algorithm. +The SubAir is also parameterized by the specific SHA-2 variant's constants. +The VM extension AIR constrains the correctness of the message padding, while the SubAir adds all other constraints related to the hash algorithm. +The VM extension AIR also constrains memory reads and writes. ### A gotcha about padding rows There are two senses of the word padding used in the context of this chip and this can be confusing. -First, we use padding to refer to the extra bits added to the message that is input to the SHA-256 algorithm in order to make the input's length a multiple of 512 bits. 
-So, we may use the term 'padding rows' to refer to round rows that correspond to the padded bits of a message (as in `Sha256VmAir::eval_padding_row`). +First, we use padding to refer to the extra bits added to the message that is input to the hash algorithm in order to make the input's length a multiple of the block size. +So, we may use the term 'padding rows' to refer to round rows that correspond to the padded bits of a message (as in `Sha2VmAir::eval_padding_row`). Second, the dummy rows that are added to the trace to make the trace height a power of 2 are also called padding rows (see the `is_padding_row` flag). In the SubAir, padding row probably means dummy row. -In the VM air, it probably refers to SHA-256 padding. \ No newline at end of file +In the VM air, it probably refers to the message padding. \ No newline at end of file diff --git a/extensions/sha256/circuit/build.rs b/extensions/sha2/circuit/build.rs similarity index 74% rename from extensions/sha256/circuit/build.rs rename to extensions/sha2/circuit/build.rs index bcb991c8b9..08ed3d4678 100644 --- a/extensions/sha256/circuit/build.rs +++ b/extensions/sha2/circuit/build.rs @@ -11,14 +11,14 @@ fn main() { let builder: CudaBuilder = CudaBuilder::new() .include_from_dep("DEP_CUDA_COMMON_INCLUDE") .include("../../../crates/circuits/primitives/cuda/include") - .include("../../../crates/circuits/sha256-air/cuda/include") + .include("../../../crates/circuits/sha2-air/cuda/include") .include("../../../crates/vm/cuda/include") .watch("cuda") .watch("../../../crates/circuits/primitives/cuda") - .watch("../../../crates/circuits/sha256-air/cuda") + .watch("../../../crates/circuits/sha2-air/cuda") .watch("../../../crates/vm/cuda") - .library_name("tracegen_gpu_sha256") - .file("cuda/src/sha256.cu"); + .library_name("tracegen_gpu_sha2") + .file("cuda/src/sha2.cu"); builder.emit_link_directives(); builder.build(); diff --git a/crates/circuits/sha256-air/cuda/include/sha256-air/columns.cuh 
b/extensions/sha2/circuit/cuda/include/block_hasher/columns.cuh similarity index 100% rename from crates/circuits/sha256-air/cuda/include/sha256-air/columns.cuh rename to extensions/sha2/circuit/cuda/include/block_hasher/columns.cuh diff --git a/crates/circuits/sha256-air/cuda/include/sha256-air/tracegen.cuh b/extensions/sha2/circuit/cuda/include/block_hasher/tracegen.cuh similarity index 100% rename from crates/circuits/sha256-air/cuda/include/sha256-air/tracegen.cuh rename to extensions/sha2/circuit/cuda/include/block_hasher/tracegen.cuh diff --git a/crates/circuits/sha256-air/cuda/include/sha256-air/utils.cuh b/extensions/sha2/circuit/cuda/include/block_hasher/utils.cuh similarity index 100% rename from crates/circuits/sha256-air/cuda/include/sha256-air/utils.cuh rename to extensions/sha2/circuit/cuda/include/block_hasher/utils.cuh diff --git a/extensions/sha256/circuit/cuda/src/sha256.cu b/extensions/sha2/circuit/cuda/src/sha2.cu similarity index 99% rename from extensions/sha256/circuit/cuda/src/sha256.cu rename to extensions/sha2/circuit/cuda/src/sha2.cu index d1939a8c9a..c10ef4ae98 100644 --- a/extensions/sha256/circuit/cuda/src/sha256.cu +++ b/extensions/sha2/circuit/cuda/src/sha2.cu @@ -1,15 +1,15 @@ #include "launcher.cuh" #include "primitives/constants.h" #include "primitives/trace_access.h" -#include "sha256-air/columns.cuh" -#include "sha256-air/tracegen.cuh" -#include "sha256-air/utils.cuh" +#include "sha2-block-hasher/columns.cuh" +#include "sha2-block-hasher/tracegen.cuh" +#include "sha2-block-hasher/utils.cuh" #include "system/memory/controller.cuh" #include "system/memory/offline_checker.cuh" #include using namespace riscv; -using namespace sha256; +using namespace sha2; __device__ inline void write_round_padding_flags_encoder( RowSlice row, diff --git a/extensions/sha256/circuit/src/cuda_abi.rs b/extensions/sha2/circuit/src/cuda/cuda_abi.rs similarity index 100% rename from extensions/sha256/circuit/src/cuda_abi.rs rename to 
extensions/sha2/circuit/src/cuda/cuda_abi.rs diff --git a/extensions/sha2/circuit/src/cuda/mod.rs b/extensions/sha2/circuit/src/cuda/mod.rs new file mode 100644 index 0000000000..744df2d7fb --- /dev/null +++ b/extensions/sha2/circuit/src/cuda/mod.rs @@ -0,0 +1,127 @@ +use std::{iter::repeat_n, sync::Arc}; + +use derive_new::new; +use openvm_circuit::{ + arch::{DenseRecordArena, MultiRowLayout, RecordSeeker}, + utils::next_power_of_two_or_zero, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupChipGPU, var_range::VariableRangeCheckerChipGPU, +}; +use openvm_cuda_backend::{ + base::DeviceMatrix, chip::get_empty_air_proving_ctx, prelude::F, prover_backend::GpuBackend, +}; +use openvm_cuda_common::{copy::MemCopyH2D, d_buffer::DeviceBuffer}; +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_sha256_air::{get_sha256_num_blocks, SHA256_HASH_WORDS, SHA256_ROWS_PER_BLOCK}; +use openvm_stark_backend::{prover::types::AirProvingContext, Chip}; + +use crate::{ + cuda_abi::sha256::{ + sha256_fill_invalid_rows, sha256_first_pass_tracegen, sha256_hash_computation, + sha256_second_pass_dependencies, + }, + Sha256VmMetadata, Sha256VmRecordMut, SHA256VM_WIDTH, +}; + +// ===== SHA256 GPU CHIP IMPLEMENTATION ===== +#[derive(new)] +pub struct Sha256VmChipGpu { + pub range_checker: Arc, + pub bitwise_lookup: Arc>, + pub ptr_max_bits: u32, + pub timestamp_max_bits: u32, +} + +impl Chip for Sha256VmChipGpu { + fn generate_proving_ctx(&self, mut arena: DenseRecordArena) -> AirProvingContext { + let records = arena.allocated_mut(); + if records.is_empty() { + return get_empty_air_proving_ctx::(); + } + + let mut record_offsets = Vec::::new(); + let mut block_to_record_idx = Vec::::new(); + let mut block_offsets = Vec::::new(); + let mut offset_so_far = 0; + let mut num_blocks_so_far: u32 = 0; + + while offset_so_far < records.len() { + record_offsets.push(offset_so_far); + block_offsets.push(num_blocks_so_far); + + let record = RecordSeeker::< + 
DenseRecordArena, + Sha256VmRecordMut, + MultiRowLayout, + >::get_record_at(&mut offset_so_far, records); + + let num_blocks = get_sha256_num_blocks(record.inner.len); + let record_idx = record_offsets.len() - 1; + + block_to_record_idx.extend(repeat_n(record_idx as u32, num_blocks as usize)); + num_blocks_so_far += num_blocks; + } + + assert_eq!(num_blocks_so_far as usize, block_to_record_idx.len()); + assert_eq!(offset_so_far, records.len()); + assert_eq!(block_offsets.len(), record_offsets.len()); + + let d_records = records.to_device().unwrap(); + let d_record_offsets = record_offsets.to_device().unwrap(); + let d_block_offsets = block_offsets.to_device().unwrap(); + let d_block_to_record_idx = block_to_record_idx.to_device().unwrap(); + + let d_prev_hashes = DeviceBuffer::::with_capacity( + num_blocks_so_far as usize * SHA256_HASH_WORDS, // 8 words per SHA256 hash block + ); + + unsafe { + sha256_hash_computation( + &d_records, + record_offsets.len(), + &d_record_offsets, + &d_block_offsets, + &d_prev_hashes, + num_blocks_so_far, + ) + .expect("Hash computation kernel failed"); + } + + let rows_used = num_blocks_so_far as usize * SHA256_ROWS_PER_BLOCK; + let trace_height = next_power_of_two_or_zero(rows_used); + let d_trace = DeviceMatrix::::with_capacity(trace_height, SHA256VM_WIDTH); + + unsafe { + sha256_first_pass_tracegen( + d_trace.buffer(), + trace_height, + &d_records, + record_offsets.len(), + &d_record_offsets, + &d_block_offsets, + &d_block_to_record_idx, + num_blocks_so_far, + &d_prev_hashes, + self.ptr_max_bits, + &self.range_checker.count, + &self.bitwise_lookup.count, + RV32_CELL_BITS as u32, + self.timestamp_max_bits, + ) + .expect("First pass trace generation failed"); + } + + unsafe { + sha256_fill_invalid_rows(d_trace.buffer(), trace_height, rows_used) + .expect("Invalid rows filling failed"); + } + + unsafe { + sha256_second_pass_dependencies(d_trace.buffer(), trace_height, rows_used) + .expect("Second pass trace generation failed"); + } + 
+ AirProvingContext::simple_no_pis(d_trace) + } +} diff --git a/extensions/sha256/circuit/src/extension/cuda.rs b/extensions/sha2/circuit/src/extension/cuda.rs similarity index 100% rename from extensions/sha256/circuit/src/extension/cuda.rs rename to extensions/sha2/circuit/src/extension/cuda.rs diff --git a/extensions/sha2/circuit/src/extension/mod.rs b/extensions/sha2/circuit/src/extension/mod.rs new file mode 100644 index 0000000000..7b8fd0e2ce --- /dev/null +++ b/extensions/sha2/circuit/src/extension/mod.rs @@ -0,0 +1,271 @@ +use std::{ + result::Result, + sync::{Arc, Mutex}, +}; + +use derive_more::derive::From; +use openvm_circuit::{ + arch::{ + AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge, + ExecutorInventoryBuilder, ExecutorInventoryError, InitFileGenerator, MatrixRecordArena, + RowMajorMatrixArena, SystemConfig, VmBuilder, VmChipComplex, VmCircuitExtension, + VmExecutionExtension, VmProverExtension, + }, + system::{ + memory::SharedMemoryHelper, SystemChipInventory, SystemCpuBuilder, SystemExecutor, + SystemPort, + }, +}; +use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor, VmConfig}; +use openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, +}; +use openvm_instructions::LocalOpcode; +use openvm_rv32im_circuit::{ + Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, +}; +use openvm_sha2_air::{Sha256Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_field::PrimeField32, + prover::cpu::{CpuBackend, CpuDevice}, +}; +use openvm_stark_sdk::engine::StarkEngine; +use serde::{Deserialize, Serialize}; + +use crate::{Sha2BlockHasherChip, Sha2BlockHasherVmAir, Sha2MainAir, Sha2MainChip, Sha2VmExecutor}; + +cfg_if::cfg_if! 
{ + if #[cfg(feature = "cuda")] { + mod cuda; + pub use self::cuda::*; + pub use self::cuda::Sha256GpuProverExt as Sha2ProverExt; + pub use self::cuda::Sha256Rv32GpuBuilder as Sha2Rv32Builder; + } else { + pub use self::Sha2CpuProverExt as Sha2ProverExt; + pub use self::Sha2Rv32CpuBuilder as Sha2Rv32Builder; + } +} + +#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] +pub struct Sha2Rv32Config { + #[config(executor = "SystemExecutor")] + pub system: SystemConfig, + #[extension] + pub rv32i: Rv32I, + #[extension] + pub rv32m: Rv32M, + #[extension] + pub io: Rv32Io, + #[extension] + pub sha2: Sha2, +} + +impl Default for Sha2Rv32Config { + fn default() -> Self { + Self { + system: SystemConfig::default(), + rv32i: Rv32I, + rv32m: Rv32M::default(), + io: Rv32Io, + sha2: Sha2, + } + } +} + +// Default implementation uses no init file +impl InitFileGenerator for Sha2Rv32Config {} + +#[derive(Clone)] +pub struct Sha2Rv32CpuBuilder; + +impl VmBuilder for Sha2Rv32CpuBuilder +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + Val: PrimeField32, +{ + type VmConfig = Sha2Rv32Config; + type SystemChipInventory = SystemChipInventory; + type RecordArena = MatrixRecordArena>; + + fn create_chip_complex( + &self, + config: &Sha2Rv32Config, + circuit: AirInventory, + ) -> Result< + VmChipComplex, + ChipInventoryError, + > { + let mut chip_complex = + VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; + let inventory = &mut chip_complex.inventory; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32i, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; + VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; + VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha2, inventory)?; + Ok(chip_complex) + } +} + +// =================================== VM Extension Implementation ================================= 
+#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] +pub struct Sha2; + +#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] +pub enum Sha2Executor { + Sha256(Sha2VmExecutor), + Sha512(Sha2VmExecutor), +} + +impl VmExecutionExtension for Sha2 { + type Executor = Sha2Executor; + + fn extend_execution( + &self, + inventory: &mut ExecutorInventoryBuilder, + ) -> Result<(), ExecutorInventoryError> { + let pointer_max_bits = inventory.pointer_max_bits(); + + let sha256_executor = + Sha2VmExecutor::::new(Rv32Sha2Opcode::CLASS_OFFSET, pointer_max_bits); + inventory.add_executor(sha256_executor, [Rv32Sha2Opcode::SHA256.global_opcode()])?; + + let sha512_executor = + Sha2VmExecutor::::new(Rv32Sha2Opcode::CLASS_OFFSET, pointer_max_bits); + inventory.add_executor(sha512_executor, [Rv32Sha2Opcode::SHA512.global_opcode()])?; + + Ok(()) + } +} + +impl VmCircuitExtension for Sha2 { + fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { + let bitwise_lu = { + let existing_air = inventory.find_air::>().next(); + if let Some(air) = existing_air { + air.bus + } else { + let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); + let air = BitwiseOperationLookupAir::<8>::new(bus); + inventory.add_air(air); + air.bus + } + }; + + // this bus will be used for communication between the block hasher chip and the main chip + let sha2_bus_index = inventory.new_bus_idx(); + + // SHA-256 + let sha256_main_air = Sha2MainAir::::new( + inventory.system().port(), + bitwise_lu, + inventory.pointer_max_bits(), + sha2_bus_index, + Rv32Sha2Opcode::CLASS_OFFSET, + ); + inventory.add_air(sha256_main_air); + + let sha256_block_hasher_air = + Sha2BlockHasherVmAir::::new(bitwise_lu, sha2_bus_index); + inventory.add_air(sha256_block_hasher_air); + + // SHA-512 + let sha512_main_air = Sha2MainAir::::new( + inventory.system().port(), + bitwise_lu, + inventory.pointer_max_bits(), + sha2_bus_index, + 
Rv32Sha2Opcode::CLASS_OFFSET, + ); + inventory.add_air(sha512_main_air); + + let sha512_block_hasher_air = + Sha2BlockHasherVmAir::::new(bitwise_lu, sha2_bus_index); + inventory.add_air(sha512_block_hasher_air); + + Ok(()) + } +} + +pub struct Sha2CpuProverExt; +// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, +// BitwiseOperationLookupChip) are specific to CpuBackend. +impl VmProverExtension for Sha2CpuProverExt +where + SC: StarkGenericConfig, + E: StarkEngine, PD = CpuDevice>, + RA: RowMajorMatrixArena> + Send + Sync + 'static, + Val: PrimeField32, +{ + fn extend_prover( + &self, + _: &Sha2, + inventory: &mut ChipInventory>, + ) -> Result<(), ChipInventoryError> { + let range_checker = inventory.range_checker()?.clone(); + let timestamp_max_bits = inventory.timestamp_max_bits(); + let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); + let pointer_max_bits = inventory.airs().pointer_max_bits(); + + let bitwise_lu = { + let existing_chip = inventory + .find_chip::>() + .next(); + if let Some(chip) = existing_chip { + chip.clone() + } else { + let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; + let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); + inventory.add_periphery_chip(chip.clone()); + chip + } + }; + + // SHA-256 + inventory.next_air::>()?; + // the arena will be passed to the block hasher chip by the main chip + let records = Arc::new(Mutex::new(None)); + let sha256_block_hasher_chip = Sha2BlockHasherChip::, Sha256Config>::new( + bitwise_lu.clone(), + pointer_max_bits, + mem_helper.clone(), + records.clone(), + ); + inventory.add_periphery_chip(sha256_block_hasher_chip); + + inventory.next_air::>()?; + let sha256_main_chip = Sha2MainChip::, Sha256Config>::new( + records, + bitwise_lu.clone(), + pointer_max_bits, + mem_helper.clone(), + ); + inventory.add_executor_chip(sha256_main_chip); + + // SHA-512 + inventory.next_air::>()?; + // the arena will be 
passed to the block hasher chip by the main chip + let records = Arc::new(Mutex::new(None)); + let sha512_block_hasher_chip = Sha2BlockHasherChip::, Sha512Config>::new( + bitwise_lu.clone(), + pointer_max_bits, + mem_helper.clone(), + records.clone(), + ); + inventory.add_periphery_chip(sha512_block_hasher_chip); + + inventory.next_air::>()?; + let sha512_main_chip = Sha2MainChip::, Sha512Config>::new( + records, + bitwise_lu.clone(), + pointer_max_bits, + mem_helper.clone(), + ); + inventory.add_executor_chip(sha512_main_chip); + + Ok(()) + } +} diff --git a/extensions/sha2/circuit/src/lib.rs b/extensions/sha2/circuit/src/lib.rs new file mode 100644 index 0000000000..46422b8141 --- /dev/null +++ b/extensions/sha2/circuit/src/lib.rs @@ -0,0 +1,10 @@ +mod sha2_chips; +pub use sha2_chips::*; + +mod extension; +pub use extension::*; + +#[cfg(feature = "cuda")] +mod cuda; +#[cfg(feature = "cuda")] +pub use cuda::*; diff --git a/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/air.rs b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/air.rs new file mode 100644 index 0000000000..9de0166da4 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/air.rs @@ -0,0 +1,160 @@ +use std::{cmp::max, iter::once, marker::PhantomData}; + +use ndarray::s; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupBus, + encoder::Encoder, + utils::{not, select}, + SubAir, +}; +use openvm_sha2_air::{compose, Sha2BlockHasherSubAir, Sha2DigestColsRef, Sha2RoundColsRef}; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use crate::{ + MessageType, Sha2BlockHasherVmConfig, Sha2BlockHasherVmDigestColsRef, + Sha2BlockHasherVmRoundColsRef, INNER_OFFSET, +}; + +pub struct Sha2BlockHasherVmAir { + pub inner: 
Sha2BlockHasherSubAir, + pub sha2_bus: PermutationCheckBus, +} + +impl Sha2BlockHasherVmAir { + pub fn new(bitwise_lookup_bus: BitwiseOperationLookupBus, sha2_bus_idx: BusIndex) -> Self { + Self { + inner: Sha2BlockHasherSubAir::new(bitwise_lookup_bus), + sha2_bus: PermutationCheckBus::new(sha2_bus_idx), + } + } +} + +impl BaseAirWithPublicValues for Sha2BlockHasherVmAir {} +impl PartitionedBaseAir for Sha2BlockHasherVmAir {} +impl BaseAir for Sha2BlockHasherVmAir { + fn width(&self) -> usize { + C::BLOCK_HASHER_WIDTH + } +} + +impl Air for Sha2BlockHasherVmAir { + fn eval(&self, builder: &mut AB) { + self.inner.eval(builder, INNER_OFFSET); + self.eval_interactions(builder); + self.eval_request_id(builder); + } +} + +impl Sha2BlockHasherVmAir { + fn eval_interactions(&self, builder: &mut AB) { + let main = builder.main(); + let local_slice = main.row_slice(0); + let next_slice = main.row_slice(1); + + let local = Sha2BlockHasherVmDigestColsRef::::from::( + &local_slice[..C::BLOCK_HASHER_DIGEST_WIDTH], + ); + + // Receive (STATE, request_id, prev_state_as_u16s, new_state) on the sha2 bus + self.sha2_bus.receive( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::State as u8), + (*local.request_id).into(), + ] + .into_iter() + .chain(local.inner.prev_hash.flatten().map(|x| (*x).into())) + .chain(local.inner.final_hash.flatten().map(|x| (*x).into())), + *local.inner.flags.is_digest_row, + ); + + let local = Sha2BlockHasherVmRoundColsRef::::from::( + &local_slice[..C::BLOCK_HASHER_ROUND_WIDTH], + ); + let next = Sha2BlockHasherVmRoundColsRef::::from::( + &next_slice[..C::BLOCK_HASHER_ROUND_WIDTH], + ); + + let is_local_first_row = self + .inner + .row_idx_encoder + .contains_flag::(local.inner.flags.row_idx.to_slice().unwrap(), &[0]); + + // Taken from old Sha256VmChip: + // https://github.com/openvm-org/openvm/blob/c2e376e6059c8bbf206736cf01d04cda43dfc42d/extensions/sha256/circuit/src/sha256_chip/air.rs#L310C1-L318C1 + let get_ith_byte = |i: usize, cols: 
&Sha2BlockHasherVmRoundColsRef| { + debug_assert!(i < C::WORD_U8S * C::ROUNDS_PER_ROW); + let row_idx = i / C::WORD_U8S; + let word: Vec = cols + .inner + .message_schedule + .w + .row(row_idx) + .into_iter() + .copied() + .collect::>(); + // Need to reverse the byte order to match the endianness of the memory + let byte_idx = C::WORD_U8S - i % C::WORD_U8S - 1; + compose::(&word[byte_idx * 8..(byte_idx + 1) * 8], 1) + }; + + let local_message = (0..C::WORD_U8S * C::ROUNDS_PER_ROW).map(|i| get_ith_byte(i, &local)); + let next_message = (0..C::WORD_U8S * C::ROUNDS_PER_ROW).map(|i| get_ith_byte(i, &next)); + + // Receive (MESSAGE_1, request_id, first_half_of_message) on the sha2 bus + self.sha2_bus.receive( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::Message1 as u8), + (*local.request_id).into(), + ] + .into_iter() + .chain(local_message.clone()) + .chain(next_message.clone()), + is_local_first_row * local.inner.flags.is_enabled(), + ); + + let is_local_third_row = self + .inner + .row_idx_encoder + .contains_flag::(local.inner.flags.row_idx.to_slice().unwrap(), &[2]); + + // Send (MESSAGE_2, request_id, second_half_of_message) to the sha2 bus + self.sha2_bus.receive( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::Message2 as u8), + (*local.request_id).into(), + ] + .into_iter() + .chain(local_message) + .chain(next_message), + is_local_third_row * local.inner.flags.is_enabled(), + ); + } + + fn eval_request_id(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let next = main.row_slice(1); + + // doesn't matter if we use round or digest cols here, since we only access + // request_id and inner.flags.is_last block, which are common to both + // field + let local = + Sha2BlockHasherVmRoundColsRef::::from::(&local[..C::BLOCK_HASHER_WIDTH]); + let next = + Sha2BlockHasherVmRoundColsRef::::from::(&next[..C::BLOCK_HASHER_WIDTH]); + + builder + .when_transition() + .when(not(*local.inner.flags.is_digest_row)) 
+ .assert_eq(*next.request_id, *local.request_id); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/columns.rs b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/columns.rs new file mode 100644 index 0000000000..f2159f8dda --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/columns.rs @@ -0,0 +1,59 @@ +use openvm_circuit_primitives_derive::ColsRef; +use openvm_sha2_air::{ + Sha2BlockHasherSubairConfig, Sha2DigestCols, Sha2DigestColsRef, Sha2DigestColsRefMut, + Sha2RoundCols, Sha2RoundColsRef, Sha2RoundColsRefMut, +}; + +// offset in the columns struct where the inner column start +pub const INNER_OFFSET: usize = 1; + +// Just adding request_id to both columns structs +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2BlockHasherVmRoundCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub request_id: T, + pub inner: Sha2RoundCols< + T, + WORD_BITS, + WORD_U8S, + WORD_U16S, + ROUNDS_PER_ROW, + ROUNDS_PER_ROW_MINUS_ONE, + ROW_VAR_CNT, + >, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2BlockHasherSubairConfig)] +pub struct Sha2BlockHasherVmDigestCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const HASH_WORDS: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub request_id: T, + pub inner: Sha2DigestCols< + T, + WORD_BITS, + WORD_U8S, + WORD_U16S, + HASH_WORDS, + ROUNDS_PER_ROW, + ROUNDS_PER_ROW_MINUS_ONE, + ROW_VAR_CNT, + >, +} diff --git a/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/config.rs b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/config.rs new file mode 100644 index 0000000000..9848e7cf4a --- /dev/null +++ 
b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/config.rs @@ -0,0 +1,52 @@ +use openvm_sha2_air::{Sha256Config, Sha2BlockHasherSubairConfig, Sha384Config, Sha512Config}; + +use crate::{Sha2BlockHasherVmDigestColsRef, Sha2BlockHasherVmRoundColsRef}; + +pub trait Sha2BlockHasherVmConfig: Sha2BlockHasherSubairConfig { + /// Width of the Sha2VmRoundCols + const BLOCK_HASHER_ROUND_WIDTH: usize; + /// Width of the Sha2DigestCols + const BLOCK_HASHER_DIGEST_WIDTH: usize; + /// Width of the Sha2BlockHasherCols + const BLOCK_HASHER_WIDTH: usize; +} + +impl Sha2BlockHasherVmConfig for Sha256Config { + const BLOCK_HASHER_ROUND_WIDTH: usize = + Sha2BlockHasherVmRoundColsRef::::width::(); + const BLOCK_HASHER_DIGEST_WIDTH: usize = + Sha2BlockHasherVmDigestColsRef::::width::(); + const BLOCK_HASHER_WIDTH: usize = + if Self::BLOCK_HASHER_ROUND_WIDTH > Self::BLOCK_HASHER_DIGEST_WIDTH { + Self::BLOCK_HASHER_ROUND_WIDTH + } else { + Self::BLOCK_HASHER_DIGEST_WIDTH + }; +} + +impl Sha2BlockHasherVmConfig for Sha512Config { + const BLOCK_HASHER_ROUND_WIDTH: usize = + Sha2BlockHasherVmRoundColsRef::::width::(); + const BLOCK_HASHER_DIGEST_WIDTH: usize = + Sha2BlockHasherVmDigestColsRef::::width::(); + const BLOCK_HASHER_WIDTH: usize = if ::BLOCK_HASHER_ROUND_WIDTH + > Self::BLOCK_HASHER_DIGEST_WIDTH + { + Self::BLOCK_HASHER_ROUND_WIDTH + } else { + Self::BLOCK_HASHER_DIGEST_WIDTH + }; +} + +impl Sha2BlockHasherVmConfig for Sha384Config { + const BLOCK_HASHER_ROUND_WIDTH: usize = + Sha2BlockHasherVmRoundColsRef::::width::(); + const BLOCK_HASHER_DIGEST_WIDTH: usize = + Sha2BlockHasherVmDigestColsRef::::width::(); + const BLOCK_HASHER_WIDTH: usize = + if Self::BLOCK_HASHER_ROUND_WIDTH > Self::BLOCK_HASHER_DIGEST_WIDTH { + Self::BLOCK_HASHER_ROUND_WIDTH + } else { + Self::BLOCK_HASHER_DIGEST_WIDTH + }; +} diff --git a/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/mod.rs b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/mod.rs new file mode 100644 index 
0000000000..4a19d5011f --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/mod.rs @@ -0,0 +1,63 @@ +mod air; +mod columns; + +mod config; +mod trace; + +use std::{ + cell::Cell, + marker::PhantomData, + sync::{Arc, Mutex}, +}; + +pub use air::*; +pub use columns::*; +pub use config::*; +use openvm_circuit::{ + arch::{RowMajorMatrixArena, VmChipWrapper}, + system::memory::SharedMemoryHelper, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, encoder::Encoder, +}; +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_sha2_air::{Sha2BlockHasherFillerHelper, Sha2BlockHasherSubairConfig}; +use openvm_stark_backend::p3_matrix::dense::RowMajorMatrix; +pub use trace::*; + +pub use super::config::*; + +pub struct Sha2BlockHasherChip { + pub inner: Sha2BlockHasherFillerHelper, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pub pointer_max_bits: usize, + pub mem_helper: SharedMemoryHelper, + // This Arc>> is shared with the main chip (Sha2MainChip). + // When the main chip's tracegen is done, it will set the value of the mutex to Some(records) + // and then the block hasher chip can see the records and use it to generate its trace. + // The arc mutex is not strictly necessary (we could just use a Cell) because tracegen is done + // sequentially over the list of chips (although it is parallelized within each chip), but the + // overhead of using a thread-safe type is negligible since we only access the 'records' field + // twice (once to set the value and once to get the value). + // So, we will just use an arc mutex to avoid overcomplicating things. 
+ pub records: Arc>>>, + _phantom: PhantomData, +} + +impl Sha2BlockHasherChip { + pub fn new( + bitwise_lookup_chip: SharedBitwiseOperationLookupChip, + pointer_max_bits: usize, + mem_helper: SharedMemoryHelper, + records: Arc>>>, + ) -> Self { + Self { + inner: Sha2BlockHasherFillerHelper::new(), + bitwise_lookup_chip, + pointer_max_bits, + mem_helper, + records, + _phantom: PhantomData, + } + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/trace.rs b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/trace.rs new file mode 100644 index 0000000000..4fa9852cf7 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/block_hasher_chip/trace.rs @@ -0,0 +1,280 @@ +use std::{ + array::{self, from_fn}, + borrow::{Borrow, BorrowMut}, + cmp::min, + marker::PhantomData, + mem, + ops::Range, + sync::Arc, +}; + +use itertools::Itertools; +use openvm_circuit::{ + arch::{ + CustomBorrow, MultiRowLayout, MultiRowMetadata, PreflightExecutor, RecordArena, + SizedRecord, VmStateMut, *, + }, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + MemoryAuxColsFactory, + }, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::SharedBitwiseOperationLookupChip, + encoder::Encoder, + utils::{compose, next_power_of_two_or_zero}, + AlignedBytesBorrow, +}; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; +use openvm_sha2_air::{ + be_limbs_into_word, big_sig0, big_sig1, ch, le_limbs_into_word, maj, + set_arrayview_from_u8_slice, word_into_bits, word_into_u16_limbs, Sha2BlockHasherFillerHelper, + Sha2DigestColsRefMut, Sha2RoundColsRef, Sha2RoundColsRefMut, WrappingAdd, +}; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_air::{Air, AirBuilder, 
BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_maybe_rayon::prelude::*, + prover::{cpu::CpuBackend, types::AirProvingContext}, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, +}; + +use crate::{ + Sha2BlockHasherChip, Sha2BlockHasherVmConfig, Sha2BlockHasherVmRoundColsRefMut, Sha2Config, + Sha2Metadata, Sha2RecordLayout, Sha2RecordMut, INNER_OFFSET, +}; + +// We don't use the record arena associated with this chip. Instead, we will use the record arena +// provided by the main chip, which will be passed to this chip after the main chip's tracegen is +// done. +impl Chip> for Sha2BlockHasherChip, C> +where + Val: PrimeField32, + SC: StarkGenericConfig, +{ + fn generate_proving_ctx(&self, _: R) -> AirProvingContext> { + // SAFETY: the tracegen for Sha2MainChip must be done before this chip's tracegen + let mut records = self.records.lock().unwrap(); + let mut records = records.take().unwrap(); + // 1 record per instruction + let num_instructions = records.height(); + let rows_used = num_instructions * C::ROWS_PER_BLOCK; + + let height = next_power_of_two_or_zero(rows_used); + let trace = Val::::zero_vec(height * C::BLOCK_HASHER_WIDTH); + let mut trace_matrix = RowMajorMatrix::new(trace, C::BLOCK_HASHER_WIDTH); + + self.fill_trace(&mut trace_matrix, rows_used, &mut records); + + AirProvingContext::simple(Arc::new(trace_matrix), vec![]) + } +} + +impl Sha2BlockHasherChip +where + F: PrimeField32, + C: Sha2BlockHasherVmConfig, +{ + fn fill_trace( + &self, + trace_matrix: &mut RowMajorMatrix, + rows_used: usize, + records: &mut RowMajorMatrix, + ) { + if rows_used == 0 { + return; + } + + let trace = &mut trace_matrix.values[..]; + + // fill in dummy rows + trace[rows_used * C::BLOCK_HASHER_WIDTH..] 
+ .par_chunks_exact_mut(C::BLOCK_HASHER_WIDTH) + .for_each(|row| { + let cols = Sha2RoundColsRefMut::from::( + &mut row[INNER_OFFSET..INNER_OFFSET + C::SUBAIR_ROUND_WIDTH], + ); + + self.inner.generate_default_row(cols); + }); + + // fill in used rows + trace[..rows_used * C::BLOCK_HASHER_WIDTH] + .par_chunks_exact_mut(C::BLOCK_HASHER_WIDTH * C::ROWS_PER_BLOCK) + .zip(records.par_rows_mut()) + .enumerate() + .for_each(|(block_idx, (block_slice, mut record))| { + // SAFETY: + // - caller ensures `records` contains a valid record representation that was + // previously written by the executor + // - records contains a valid Sha2RecordMut with the exact layout specified + // - get_record_from_slice will correctly split the buffer into header, input, and + // aux components based on this layout + let record: Sha2RecordMut = unsafe { + get_record_from_slice( + &mut record, + Sha2RecordLayout { + metadata: Sha2Metadata { + variant: C::VARIANT, + }, + }, + ) + }; + + let prev_hash = (0..C::HASH_WORDS) + .map(|i| { + be_limbs_into_word::( + &record.prev_state[i * C::WORD_U8S..(i + 1) * C::WORD_U8S] + .iter() + .map(|x| *x as u32) + .collect::>(), + ) + }) + .collect::>(); + + self.fill_block_trace( + block_slice, + record.message_bytes, + block_idx + 1, // 1-indexed + &prev_hash, + block_idx, + ); + }); + + // Do a second pass over the trace to fill in the missing values + // Note, we need to skip the very first row + trace[C::BLOCK_HASHER_WIDTH..] 
+ .par_chunks_mut(C::BLOCK_HASHER_WIDTH * C::ROWS_PER_BLOCK) + .take(rows_used / C::ROWS_PER_BLOCK) + .for_each(|chunk| { + self.inner + .generate_missing_cells(chunk, C::BLOCK_HASHER_WIDTH, INNER_OFFSET); + }); + + { + let mut rows_24_and_25 = trace + .chunks_exact_mut(C::BLOCK_HASHER_WIDTH) + .skip(24) + .take(2) + .collect::>(); + let row_24 = rows_24_and_25.remove(0); + let row_25 = rows_24_and_25.remove(0); + let row_24_cols = Sha2RoundColsRef::from::( + &row_24[INNER_OFFSET..INNER_OFFSET + C::SUBAIR_ROUND_WIDTH], + ); + let mut row_25_cols = Sha2RoundColsRefMut::from::( + &mut row_25[INNER_OFFSET..INNER_OFFSET + C::SUBAIR_ROUND_WIDTH], + ); + + // Sha2BlockHasherFillerHelper::::generate_carry_ae(row_24_cols, &mut row_25_cols); + + println!("row 25 carry_a: {:?}", row_25_cols.work_vars.carry_a); + + println!("row 25 a: {:?}", row_25_cols.work_vars.a); + println!("row 25 e: {:?}", row_25_cols.work_vars.e); + } + + { + let mut rows = trace + .chunks_exact_mut(C::BLOCK_HASHER_WIDTH) + .collect::>(); + let row_0 = rows.remove(0); + let row_31 = rows.remove(30); + let row_31_cols = Sha2RoundColsRef::from::( + &row_31[INNER_OFFSET..INNER_OFFSET + C::SUBAIR_ROUND_WIDTH], + ); + let mut row_0_cols = Sha2RoundColsRefMut::from::( + &mut row_0[INNER_OFFSET..INNER_OFFSET + C::SUBAIR_ROUND_WIDTH], + ); + + // Sha2BlockHasherFillerHelper::::generate_carry_ae(row_31_cols, &mut row_0_cols); + + println!("row 0 carry_a: {:?}", row_0_cols.work_vars.carry_a); + + println!("row 0 a: {:?}", row_0_cols.work_vars.a); + println!("row 0 e: {:?}", row_0_cols.work_vars.e); + } + } +} + +impl Sha2BlockHasherChip { + #[allow(clippy::too_many_arguments)] + fn fill_block_trace( + &self, + block_slice: &mut [F], + input: &[u8], + global_block_idx: usize, + prev_hash: &[C::Word], + request_id: usize, + ) where + F: PrimeField32, + { + debug_assert_eq!(input.len(), C::BLOCK_U8S); + debug_assert_eq!(prev_hash.len(), C::HASH_WORDS); + + // Set request_id and fill the input into carry_or_buffer + 
block_slice + .par_chunks_exact_mut(C::BLOCK_HASHER_WIDTH) + .enumerate() + .for_each(|(row_idx, row_slice)| { + // Set request_id + let cols = Sha2BlockHasherVmRoundColsRefMut::::from::( + &mut row_slice[..C::BLOCK_HASHER_WIDTH], + ); + *cols.request_id = F::from_canonical_usize(request_id); + + // Fill the input into carry_or_buffer + if row_idx < C::MESSAGE_ROWS { + let mut round_cols = Sha2RoundColsRefMut::::from::( + row_slice[INNER_OFFSET..INNER_OFFSET + C::SUBAIR_ROUND_WIDTH].borrow_mut(), + ); + // We don't actually need to set carry_or_buffer for the first 4 rows, because + // the subair won't actually use it (we used to store the memory read into here + // in the old SHA-2 chip). We will remove the subair constraints that constrain + // carry_or_buffer for the first 4 rows in the future. Then, we can skip setting + // carry_or_buffer for the first 4 rows here. + // TODO: resolve this before getting this PR reviewed. + set_arrayview_from_u8_slice( + &mut round_cols.message_schedule.carry_or_buffer, + input[row_idx * C::ROUNDS_PER_ROW * C::WORD_U8S + ..(row_idx + 1) * C::ROUNDS_PER_ROW * C::WORD_U8S] + .iter() + .copied(), + ); + } + }); + + let input_words = (0..C::BLOCK_WORDS) + .map(|i| { + be_limbs_into_word::( + &input[i * C::WORD_U8S..(i + 1) * C::WORD_U8S] + .iter() + .map(|x| *x as u32) + .collect::>(), + ) + }) + .collect::>(); + + // Fill in the inner trace + self.inner.generate_block_trace( + block_slice, + C::BLOCK_HASHER_WIDTH, + INNER_OFFSET, + &input_words, + self.bitwise_lookup_chip.clone(), + prev_hash, + true, + global_block_idx as u32, + ); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/config.rs b/extensions/sha2/circuit/src/sha2_chips/config.rs new file mode 100644 index 0000000000..8c260ebe19 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/config.rs @@ -0,0 +1,81 @@ +use itertools::Itertools; +use openvm_sha2_air::{Sha256Config, Sha2BlockHasherSubairConfig, Sha384Config, Sha512Config}; +use sha2::{ + compress256, 
compress512, digest::generic_array::GenericArray, Digest, Sha256, Sha384, Sha512, +}; + +use crate::{Sha2BlockHasherVmConfig, Sha2MainChipConfig}; + +pub const SHA2_REGISTER_READS: usize = 3; +pub const SHA2_READ_SIZE: usize = 4; +pub const SHA2_WRITE_SIZE: usize = 4; + +pub trait Sha2Config: Sha2MainChipConfig + Sha2BlockHasherVmConfig { + /// Number of bits used to store the message length (part of the message padding) + const MESSAGE_LENGTH_BITS: usize; + + // Preconditions: + // - state.len() >= Self::STATE_BYTES + // - input.len() == Self::BLOCK_BYTES + fn compress(state: &mut [u8], input: &[u8]); + + fn hash(message: &[u8]) -> Vec; +} + +impl Sha2Config for Sha256Config { + const MESSAGE_LENGTH_BITS: usize = 64; + + // TODO: do this faster + fn compress(state: &mut [u8], input: &[u8]) { + debug_assert!(state.len() >= Sha256Config::STATE_BYTES); + debug_assert!(input.len() == Sha256Config::BLOCK_BYTES); + + let state_u32s = state + .chunks_exact(4) + .map(|chunk| u32::from_be_bytes(chunk.try_into().unwrap())) + .collect_vec(); + let mut state_u32s_array = state_u32s.try_into().unwrap(); + // let state: &mut [u32; 8] = unsafe { &mut *(state.as_mut_ptr() as *mut [u32; 8]) }; + let input_array = GenericArray::from_slice(input); + compress256(&mut state_u32s_array, &[*input_array]); + + state.copy_from_slice( + &state_u32s_array + .iter() + .flat_map(|x| x.to_be_bytes()) + .collect_vec(), + ); + } + + fn hash(message: &[u8]) -> Vec { + Sha256::digest(message).to_vec() + } +} + +impl Sha2Config for Sha512Config { + const MESSAGE_LENGTH_BITS: usize = 128; + + fn compress(state: &mut [u8], input: &[u8]) { + debug_assert!(state.len() >= Sha512Config::STATE_BYTES); + debug_assert!(input.len() == Sha512Config::BLOCK_BYTES); + let state: &mut [u64; 8] = unsafe { &mut *(state.as_mut_ptr() as *mut [u64; 8]) }; + let input_array = GenericArray::from_slice(input); + compress512(state, &[*input_array]); + } + + fn hash(message: &[u8]) -> Vec { + 
Sha512::digest(message).to_vec() + } +} + +impl Sha2Config for Sha384Config { + const MESSAGE_LENGTH_BITS: usize = Sha512Config::MESSAGE_LENGTH_BITS; + + fn compress(state: &mut [u8], input: &[u8]) { + Sha512Config::compress(state, input); + } + + fn hash(message: &[u8]) -> Vec { + Sha384::digest(message).to_vec() + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/execution.rs b/extensions/sha2/circuit/src/sha2_chips/execution.rs new file mode 100644 index 0000000000..f13e77681d --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/execution.rs @@ -0,0 +1,210 @@ +use std::borrow::{Borrow, BorrowMut}; + +use openvm_circuit::{arch::*, system::memory::online::GuestMemory}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS}, + LocalOpcode, +}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; +use openvm_stark_backend::p3_field::PrimeField32; + +use super::{Sha2Config, Sha2VmExecutor, SHA2_READ_SIZE}; +use crate::SHA2_WRITE_SIZE; + +#[derive(AlignedBytesBorrow, Clone)] +#[repr(C)] +struct Sha2PreCompute { + a: u8, + b: u8, + c: u8, +} + +impl Executor for Sha2VmExecutor { + fn pre_compute_size(&self) -> usize { + size_of::() + } + + #[cfg(not(feature = "tco"))] + fn pre_compute( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut Sha2PreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_impl::<_, _, C>) + } + + #[cfg(feature = "tco")] + fn handler( + &self, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: ExecutionCtxTrait, + { + let data: &mut Sha2PreCompute = data.borrow_mut(); + self.pre_compute_impl(pc, inst, data)?; + Ok(execute_e1_handler::<_, _>) + } +} + +impl MeteredExecutor for Sha2VmExecutor { + fn metered_pre_compute_size(&self) -> usize 
{ + size_of::>() + } + + #[cfg(not(feature = "tco"))] + fn metered_pre_compute( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_impl::<_, _, C>) + } + + #[cfg(feature = "tco")] + fn metered_handler( + &self, + chip_idx: usize, + pc: u32, + inst: &Instruction, + data: &mut [u8], + ) -> Result, StaticProgramError> + where + Ctx: MeteredExecutionCtxTrait, + { + let data: &mut E2PreCompute = data.borrow_mut(); + data.chip_idx = chip_idx as u32; + self.pre_compute_impl(pc, inst, &mut data.data)?; + Ok(execute_e2_handler::<_, _, C>) + } +} + +#[inline(always)] +unsafe fn execute_e12_impl< + F: PrimeField32, + C: Sha2Config, + CTX: ExecutionCtxTrait, + const IS_E1: bool, +>( + pre_compute: &Sha2PreCompute, + instret: &mut u64, + pc: &mut u32, + exec_state: &mut VmExecState, +) -> u32 { + let dst = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.a as u32); + let state = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.b as u32); + let input = exec_state.vm_read(RV32_REGISTER_AS, pre_compute.c as u32); + let dst_u32 = u32::from_le_bytes(dst); + let state_u32 = u32::from_le_bytes(state); + let input_u32 = u32::from_le_bytes(input); + + let mut state_data = Vec::with_capacity(C::STATE_BYTES); + let mut input_block = Vec::with_capacity(C::BLOCK_BYTES); + for i in 0..C::STATE_READS { + state_data.extend_from_slice(&exec_state.vm_read::( + RV32_MEMORY_AS, + state_u32 + (i * SHA2_READ_SIZE) as u32, + )); + } + for i in 0..C::BLOCK_READS { + input_block.extend_from_slice(&exec_state.vm_read::( + RV32_MEMORY_AS, + input_u32 + (i * SHA2_READ_SIZE) as u32, + )); + } + + C::compress(&mut state_data, &input_block); + + for i in 0..C::STATE_WRITES { + exec_state.vm_write::( + RV32_MEMORY_AS, + dst_u32 + (i * SHA2_WRITE_SIZE) 
as u32, + &state_data[i * SHA2_WRITE_SIZE..(i + 1) * SHA2_WRITE_SIZE] + .try_into() + .unwrap(), + ); + } + + *pc = pc.wrapping_add(DEFAULT_PC_STEP); + *instret += 1; + + 1 // height delta +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e1_impl( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _instret_end: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &Sha2PreCompute = pre_compute.borrow(); + execute_e12_impl::(pre_compute, instret, pc, exec_state); +} + +#[create_handler] +#[inline(always)] +unsafe fn execute_e2_impl( + pre_compute: &[u8], + instret: &mut u64, + pc: &mut u32, + _arg: u64, + exec_state: &mut VmExecState, +) { + let pre_compute: &E2PreCompute = pre_compute.borrow(); + let height = execute_e12_impl::(&pre_compute.data, instret, pc, exec_state); + exec_state + .ctx + .on_height_change(pre_compute.chip_idx as usize, height); +} + +impl Sha2VmExecutor { + fn pre_compute_impl( + &self, + pc: u32, + inst: &Instruction, + data: &mut Sha2PreCompute, + ) -> Result<(), StaticProgramError> { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. 
+ } = inst; + let e_u32 = e.as_canonical_u32(); + if d.as_canonical_u32() != RV32_REGISTER_AS || e_u32 != RV32_MEMORY_AS { + return Err(StaticProgramError::InvalidInstruction(pc)); + } + *data = Sha2PreCompute { + a: a.as_canonical_u32() as u8, + b: b.as_canonical_u32() as u8, + c: c.as_canonical_u32() as u8, + }; + assert_eq!(&Rv32Sha2Opcode::SHA256.global_opcode(), opcode); + Ok(()) + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/main_chip/air.rs b/extensions/sha2/circuit/src/sha2_chips/main_chip/air.rs new file mode 100644 index 0000000000..e075815b91 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/main_chip/air.rs @@ -0,0 +1,334 @@ +use std::marker::PhantomData; + +use itertools::izip; +use ndarray::s; +use openvm_circuit::{ + arch::ExecutionBridge, + system::{ + memory::{offline_checker::MemoryBridge, MemoryAddress}, + SystemPort, + }, +}; +use openvm_circuit_primitives::{bitwise_op_lookup::BitwiseOperationLookupBus, utils::compose}; +use openvm_instructions::riscv::{ + RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS, +}; +use openvm_sha2_air::Sha2BlockHasherSubairConfig; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder, PermutationCheckBus}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::FieldAlgebra, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use super::config::Sha2MainChipConfig; +use crate::{MessageType, Sha2ColsRef, SHA2_READ_SIZE, SHA2_WRITE_SIZE}; + +#[derive(Clone, Debug)] +pub struct Sha2MainAir { + pub execution_bridge: ExecutionBridge, + pub memory_bridge: MemoryBridge, + pub bitwise_lookup_bus: BitwiseOperationLookupBus, + pub sha2_bus: PermutationCheckBus, + /// Maximum number of bits allowed for an address pointer + /// Must be at least 24 + pub ptr_max_bits: usize, + pub offset: usize, + _phantom: PhantomData, +} + +impl Sha2MainAir { + pub fn new( + SystemPort { + execution_bus, + program_bus, + memory_bridge, + }: SystemPort, + 
bitwise_lookup_bus: BitwiseOperationLookupBus, + ptr_max_bits: usize, + self_bus_idx: BusIndex, + offset: usize, + ) -> Self { + Self { + execution_bridge: ExecutionBridge::new(execution_bus, program_bus), + memory_bridge, + bitwise_lookup_bus, + sha2_bus: PermutationCheckBus::new(self_bus_idx), + ptr_max_bits, + offset, + _phantom: PhantomData, + } + } +} + +impl BaseAirWithPublicValues for Sha2MainAir {} +impl PartitionedBaseAir for Sha2MainAir {} +impl BaseAir for Sha2MainAir { + fn width(&self) -> usize { + C::MAIN_CHIP_WIDTH + } +} + +impl Air + for Sha2MainAir +{ + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + + let local: Sha2ColsRef = Sha2ColsRef::from::(&local[..C::MAIN_CHIP_WIDTH]); + let next: Sha2ColsRef = Sha2ColsRef::from::(&next[..C::MAIN_CHIP_WIDTH]); + + let mut timestamp_delta = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + local.instruction.from_state.timestamp + + AB::F::from_canonical_usize(timestamp_delta - 1) + }; + + self.eval_block(builder, &local, &next); + self.eval_instruction(builder, &local, &mut timestamp_pp); + self.eval_reads(builder, &local, &mut timestamp_pp); + self.eval_writes(builder, &local, &mut timestamp_pp); + } +} + +impl Sha2MainAir { + pub fn eval_block( + &self, + builder: &mut AB, + local: &Sha2ColsRef, + next: &Sha2ColsRef, + ) { + builder + .when_first_row() + .assert_zero(*local.block.request_id); + + builder.when_transition().assert_eq( + *next.block.request_id, + *local.block.request_id + AB::Expr::ONE, + ); + + let prev_state_as_u16s: Vec = local + .block + .prev_state + .exact_chunks(C::WORD_U8S) + .into_iter() + .flat_map(|word| { + word.as_slice() + .unwrap() + .chunks_exact(2) + .rev() + .map(|x| x[0] * AB::F::from_canonical_u64(1 << 8) + x[1]) + .collect::>() + }) + .collect(); + + // for each word in the new state, byte1, byte2, ..., byteN, reverse the order of the bytes + // so that it matches what the block 
hasher chip expects + let new_state_big_endian: Vec = local + .block + .new_state + .exact_chunks(C::WORD_U8S) + .into_iter() + .flat_map(|word| word.into_iter().rev().copied().collect::>()) + .collect(); + + // Send (STATE, request_id, prev_state_as_u16s, new_state) to the sha2 bus + self.sha2_bus.send( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::State as u8), + (*local.block.request_id).into(), + ] + .into_iter() + .chain(prev_state_as_u16s) + .chain(new_state_big_endian.into_iter().map(|x| x.into())), + *local.instruction.is_enabled, + ); + + // Send (MESSAGE_1, request_id, first_half_of_message) to the sha2 bus + self.sha2_bus.send( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::Message1 as u8), + (*local.block.request_id).into(), + ] + .into_iter() + .chain( + local + .block + .message_bytes + .iter() + .take(C::BLOCK_BYTES / 2) + .map(|x| (*x).into()), + ), + *local.instruction.is_enabled, + ); + + // Send (MESSAGE_2, request_id, second_half_of_message) to the sha2 bus + self.sha2_bus.send( + builder, + [ + AB::Expr::from_canonical_u8(MessageType::Message2 as u8), + (*local.block.request_id).into(), + ] + .into_iter() + .chain( + local + .block + .message_bytes + .iter() + .skip(C::BLOCK_BYTES / 2) + .map(|x| (*x).into()), + ), + *local.instruction.is_enabled, + ); + } + + pub fn eval_instruction( + &self, + builder: &mut AB, + local: &Sha2ColsRef, + timestamp_pp: &mut impl FnMut() -> AB::Expr, + ) { + for (&ptr, val, aux) in izip!( + [ + local.instruction.dst_reg_ptr, + local.instruction.state_reg_ptr, + local.instruction.input_reg_ptr + ], + [ + local.instruction.dst_ptr_limbs, + local.instruction.state_ptr_limbs, + local.instruction.input_ptr_limbs + ], + &local.mem.register_aux, + ) { + self.memory_bridge + .read::<_, _, SHA2_READ_SIZE>( + MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_REGISTER_AS), ptr), + val.to_vec().try_into().unwrap_or_else(|_| panic!()), // can't unwrap because AB::Var doesn't impl Debug + 
timestamp_pp(), + aux, + ) + .eval(builder, *local.instruction.is_enabled); + } + + // range check the memory pointers + // TODO: do I need to consider the length of the input or state? + let shift = AB::Expr::from_canonical_usize( + 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits), + ); + let needs_range_check = [ + local.instruction.dst_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + local.instruction.state_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + local.instruction.input_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + local.instruction.input_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + ]; + for pair in needs_range_check.chunks_exact(2) { + self.bitwise_lookup_bus + .send_range(pair[0] * shift.clone(), pair[1] * shift.clone()) + .eval(builder, *local.instruction.is_enabled); + } + + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(C::OPCODE as usize + self.offset), + [ + (*local.instruction.dst_reg_ptr).into(), + (*local.instruction.state_reg_ptr).into(), + (*local.instruction.input_reg_ptr).into(), + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + ], + *local.instruction.from_state, + AB::F::from_canonical_usize(C::TIMESTAMP_DELTA), + ) + .eval(builder, *local.instruction.is_enabled); + } + + pub fn eval_reads( + &self, + builder: &mut AB, + local: &Sha2ColsRef, + timestamp_pp: &mut impl FnMut() -> AB::Expr, + ) { + let input_ptr_val = compose(&local.instruction.input_ptr_limbs.to_vec(), RV32_CELL_BITS); + for i in 0..C::BLOCK_READS { + self.memory_bridge + .read::<_, _, SHA2_READ_SIZE>( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + input_ptr_val.clone() + AB::F::from_canonical_usize(i * SHA2_READ_SIZE), + ), + local + .block + .message_bytes + .slice(s![i * SHA2_READ_SIZE..(i + 1) * SHA2_READ_SIZE]) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("message bytes is not the correct size"); + }), + timestamp_pp(), + 
&local.mem.input_reads[i], + ) + .eval(builder, *local.instruction.is_enabled); + } + + let state_ptr_val = compose(&local.instruction.state_ptr_limbs.to_vec(), RV32_CELL_BITS); + for i in 0..C::STATE_READS { + self.memory_bridge + .read::<_, _, SHA2_READ_SIZE>( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + state_ptr_val.clone() + AB::F::from_canonical_usize(i * SHA2_READ_SIZE), + ), + local + .block + .prev_state + .slice(s![i * SHA2_READ_SIZE..(i + 1) * SHA2_READ_SIZE]) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("prev state is not the correct size"); + }), + timestamp_pp(), + &local.mem.state_reads[i], + ) + .eval(builder, *local.instruction.is_enabled); + } + } + + pub fn eval_writes( + &self, + builder: &mut AB, + local: &Sha2ColsRef, + timestamp_pp: &mut impl FnMut() -> AB::Expr, + ) { + let dst_ptr_val = compose(&local.instruction.dst_ptr_limbs.to_vec(), RV32_CELL_BITS); + for i in 0..C::STATE_READS { + self.memory_bridge + .write::<_, _, SHA2_WRITE_SIZE>( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + dst_ptr_val.clone() + AB::F::from_canonical_usize(i * SHA2_READ_SIZE), + ), + local + .block + .new_state + .slice(s![i * SHA2_READ_SIZE..(i + 1) * SHA2_READ_SIZE]) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("new state is not the correct size"); + }), + timestamp_pp(), + &local.mem.write_aux[i], + ) + .eval(builder, *local.instruction.is_enabled); + } + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/main_chip/columns.rs b/extensions/sha2/circuit/src/sha2_chips/main_chip/columns.rs new file mode 100644 index 0000000000..23c0fad6b2 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/main_chip/columns.rs @@ -0,0 +1,74 @@ +use openvm_circuit::{ + arch::ExecutionState, + system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, +}; +use openvm_circuit_primitives::ColsRef; +use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; + +use 
crate::{Sha2MainChipConfig, SHA2_REGISTER_READS, SHA2_WRITE_SIZE}; + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2MainChipConfig)] +pub struct Sha2Cols { + pub block: Sha2BlockCols, + pub instruction: Sha2InstructionCols, + pub mem: Sha2MemoryCols, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2MainChipConfig)] +pub struct Sha2BlockCols { + /// Identifier of this row in the interactions between the two chips + pub request_id: T, + /// Input bytes for this block + pub message_bytes: [T; BLOCK_BYTES], + // Previous state of the SHA-2 hasher object + pub prev_state: [T; STATE_BYTES], + // New state of the SHA-2 hasher object after processing this block + pub new_state: [T; STATE_BYTES], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2MainChipConfig)] +pub struct Sha2InstructionCols { + /// True for all rows that are part of opcode execution. + /// False on dummy rows only used to pad the height. + pub is_enabled: T, + #[aligned_borrow] + pub from_state: ExecutionState, + /// Pointer to address space 1 `dst` register + pub dst_reg_ptr: T, + /// Pointer to address space 1 `state` register + pub state_reg_ptr: T, + /// Pointer to address space 1 `input` register + pub input_reg_ptr: T, + // Register values + /// dst_ptr_limbs <- \[dst_reg_ptr:4\]_1 + pub dst_ptr_limbs: [T; RV32_REGISTER_NUM_LIMBS], + /// state_ptr_limbs <- \[state_reg_ptr:4\]_1 + pub state_ptr_limbs: [T; RV32_REGISTER_NUM_LIMBS], + /// input_ptr_limbs <- \[input_reg_ptr:4\]_1 + pub input_ptr_limbs: [T; RV32_REGISTER_NUM_LIMBS], +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2MainChipConfig)] +pub struct Sha2MemoryCols< + T, + const BLOCK_READS: usize, + const STATE_READS: usize, + const STATE_WRITES: usize, +> { + #[aligned_borrow] + pub register_aux: [MemoryReadAuxCols; SHA2_REGISTER_READS], + #[aligned_borrow] + pub input_reads: [MemoryReadAuxCols; BLOCK_READS], + #[aligned_borrow] + pub state_reads: 
[MemoryReadAuxCols; STATE_READS], + #[aligned_borrow] + pub write_aux: [MemoryWriteAuxCols; STATE_WRITES], +} diff --git a/extensions/sha2/circuit/src/sha2_chips/main_chip/config.rs b/extensions/sha2/circuit/src/sha2_chips/main_chip/config.rs new file mode 100644 index 0000000000..5aec638971 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/main_chip/config.rs @@ -0,0 +1,42 @@ +use openvm_sha2_air::{Sha256Config, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; + +use crate::{Sha2ColsRef, SHA2_READ_SIZE, SHA2_REGISTER_READS, SHA2_WRITE_SIZE}; + +pub trait Sha2MainChipConfig: Send + Sync + Clone { + // --- Required --- + /// Number of bytes in a SHA block (sometimes referred to as message bytes in the code) + const BLOCK_BYTES: usize; + /// Number of bytes in a SHA state + const STATE_BYTES: usize; + /// OpenVM Opcode for the instruction + const OPCODE: Rv32Sha2Opcode; + + // --- Provided --- + const BLOCK_READS: usize = Self::BLOCK_BYTES / SHA2_READ_SIZE; + const STATE_READS: usize = Self::STATE_BYTES / SHA2_READ_SIZE; + const STATE_WRITES: usize = Self::STATE_BYTES / SHA2_WRITE_SIZE; + + const TIMESTAMP_DELTA: usize = + Self::BLOCK_READS + Self::STATE_READS + Self::STATE_WRITES + SHA2_REGISTER_READS; + + const MAIN_CHIP_WIDTH: usize = Sha2ColsRef::::width::(); +} + +impl Sha2MainChipConfig for Sha256Config { + const BLOCK_BYTES: usize = 64; + const STATE_BYTES: usize = 32; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA256; +} + +impl Sha2MainChipConfig for Sha512Config { + const BLOCK_BYTES: usize = 128; + const STATE_BYTES: usize = 64; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA512; +} + +impl Sha2MainChipConfig for Sha384Config { + const BLOCK_BYTES: usize = Sha512Config::BLOCK_BYTES; + const STATE_BYTES: usize = Sha512Config::STATE_BYTES; + const OPCODE: Rv32Sha2Opcode = Sha512Config::OPCODE; +} diff --git a/extensions/sha2/circuit/src/sha2_chips/main_chip/mod.rs 
b/extensions/sha2/circuit/src/sha2_chips/main_chip/mod.rs new file mode 100644 index 0000000000..5ae8935a05 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/main_chip/mod.rs @@ -0,0 +1,52 @@ +mod air; +mod columns; +mod config; +mod trace; + +use std::{ + marker::PhantomData, + sync::{Arc, Mutex}, +}; + +pub use air::*; +pub use columns::*; +pub use config::*; +use openvm_circuit::{arch::VmChipWrapper, system::memory::SharedMemoryHelper}; +use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip; +use openvm_stark_backend::{p3_field::PrimeField32, p3_matrix::dense::RowMajorMatrix}; +pub use trace::*; + +use crate::{Sha2BlockHasherChip, Sha2Config}; + +pub struct Sha2MainChip { + // This Arc>> is shared with the block hasher chip (Sha2BlockHasherChip). + // When the main chip's tracegen is done, it will set the value of the mutex to Some(records) + // and then the block hasher chip can see the records and use them to generate its trace. + // The arc mutex is not strictly necessary (we could just use a Cell) because tracegen is done + // sequentially over the list of chips (although it is parallelized within each chip), but the + // overhead of using a thread-safe type is negligible since we only access the 'records' field + // twice (once to set the value and once to get the value). + // So, we will just use an arc mutex to avoid overcomplicating things. 
+ pub records: Arc>>>, + pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + pub pointer_max_bits: usize, + pub mem_helper: SharedMemoryHelper, + _phantom: PhantomData, +} + +impl Sha2MainChip { + pub fn new( + records: Arc>>>, + bitwise_lookup_chip: SharedBitwiseOperationLookupChip<8>, + pointer_max_bits: usize, + mem_helper: SharedMemoryHelper, + ) -> Self { + Self { + records, + bitwise_lookup_chip, + pointer_max_bits, + mem_helper, + _phantom: PhantomData, + } + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/main_chip/trace.rs b/extensions/sha2/circuit/src/sha2_chips/main_chip/trace.rs new file mode 100644 index 0000000000..30e27d671b --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/main_chip/trace.rs @@ -0,0 +1,237 @@ +use std::{ + array::{self, from_fn}, + borrow::{Borrow, BorrowMut}, + cmp::min, + sync::Arc, +}; + +use ndarray::ArrayViewMut; +use openvm_circuit::{ + arch::*, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + MemoryAuxColsFactory, + }, + utils::next_power_of_two_or_zero, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; +use openvm_sha2_air::set_arrayview_from_u8_slice; +use openvm_sha2_transpiler::Rv32Sha2Opcode; +use openvm_stark_backend::{ + config::{StarkGenericConfig, Val}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_maybe_rayon::prelude::*, + prover::{cpu::CpuBackend, types::AirProvingContext}, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, + Chip, +}; + +use crate::{ + Sha2ColsRef, Sha2ColsRefMut, Sha2Config, Sha2MainChip, Sha2Metadata, Sha2RecordHeader, + 
Sha2RecordLayout, Sha2RecordMut, SHA2_WRITE_SIZE, +}; + +// We will allocate a new trace matrix instead of using the record arena directly, +// because we want to pass the record arena to Sha2BlockHasherChip when we are done. +impl Chip> + for Sha2MainChip, C> +where + Val: PrimeField32, + SC: StarkGenericConfig, + RA: RowMajorMatrixArena> + Send + Sync, +{ + fn generate_proving_ctx(&self, arena: RA) -> AirProvingContext> { + // Since Sha2Metadata::get_num_rows() = 1, the number of rows used is equal to the number of + // SHA-2 instructions executed. + let rows_used = arena.trace_offset() / arena.width(); + + // We will fill the trace into a separate buffer, because we want to pass the arena to the + // Sha2BlockHasherChip when we are done. + // Sha2MainChip uses 1 row per instruction, we allocate rows_used * arena.width() space for + // the trace. + let height = next_power_of_two_or_zero(rows_used); + let trace = Val::::zero_vec(height * arena.width()); + let mut trace_matrix = RowMajorMatrix::new(trace, arena.width()); + let mem_helper = self.mem_helper.as_borrowed(); + + let mut records = arena.into_matrix(); + + self.fill_trace(&mem_helper, &mut trace_matrix, rows_used, &mut records); + + // Pass the records to Sha2BlockHasherChip + *self.records.lock().unwrap() = Some(records); + + AirProvingContext::simple(Arc::new(trace_matrix), vec![]) + } +} + +// Note: we would like to just impl TraceFiller here, but we can't because we need to pass the +// records and row_idx to the tracegen functions. 
+impl Sha2MainChip { + // Preconditions: + // - trace should be a matrix with width = Sha2MainAir::width() and height = rows_used + // - trace should be filled with all zeros + // - records should be a matrix with height = rows_used, where each row stores a record + pub fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace: &mut RowMajorMatrix, + rows_used: usize, + records: &mut RowMajorMatrix, + ) { + let width = trace.width(); + trace.values[..rows_used * width] + .par_chunks_exact_mut(width) + .zip(records.par_rows_mut()) + .enumerate() + .for_each(|(row_idx, (row_slice, record))| { + self.fill_trace_row_with_row_idx(mem_helper, row_slice, row_idx, record); + }); + } + + // Same as TraceFiller::fill_trace_row, except we also take the row index as a parameter. + // + // Note: the only reason the record parameter is mutable is that get_record_from_slice + // requires a &mut &mut [F] slice. This parameter type is useful in other places where + // get_record_from_slice is used, to circumvent the borrow checker. Here, we don't actually need + // this workaround (we could duplicate get_record_from_slice and modify it to take a &mut + // [F] slice), but we just use the existing function for simplicity. + fn fill_trace_row_with_row_idx( + &self, + mem_helper: &MemoryAuxColsFactory, + row_slice: &mut [F], + row_idx: usize, + mut record: &mut [F], + ) where + F: Clone, + { + // SAFETY: + // - caller ensures `record` contains a valid record representation that was previously + // written by the executor + // - record contains a valid Sha2RecordMut with the exact layout specified + // - get_record_from_slice will correctly split the buffer into header and other components + // based on this layout. 
+ let record: Sha2RecordMut = unsafe { + get_record_from_slice( + &mut record, + Sha2RecordLayout::new(Sha2Metadata { + variant: C::VARIANT, + }), + ) + }; + + // save all the components of the record on the stack so that we don't overwrite them when + // filling in the trace matrix. + let vm_record = record.inner.clone(); + + let mut message_bytes = Vec::with_capacity(C::BLOCK_BYTES); + message_bytes.extend_from_slice(record.message_bytes); + + let mut prev_state = Vec::with_capacity(C::STATE_BYTES); + prev_state.extend_from_slice(record.prev_state); + + let mut new_state = prev_state.clone(); + C::compress(&mut new_state, &message_bytes); + + let mut input_reads_aux = + Vec::with_capacity(C::BLOCK_READS * size_of::()); + input_reads_aux.extend_from_slice(record.input_reads_aux); + + let mut state_reads_aux = + Vec::with_capacity(C::STATE_READS * size_of::()); + state_reads_aux.extend_from_slice(record.state_reads_aux); + + let mut write_aux = Vec::with_capacity( + C::STATE_WRITES * size_of::>(), + ); + write_aux.extend_from_slice(record.write_aux); + + let mut cols = Sha2ColsRefMut::from::(row_slice); + + *cols.block.request_id = F::from_canonical_usize(row_idx); + set_arrayview_from_u8_slice(&mut cols.block.message_bytes, message_bytes); + set_arrayview_from_u8_slice(&mut cols.block.prev_state, prev_state); + set_arrayview_from_u8_slice(&mut cols.block.new_state, new_state); + + *cols.instruction.is_enabled = F::ONE; + cols.instruction.from_state.timestamp = F::from_canonical_u32(vm_record.timestamp); + cols.instruction.from_state.pc = F::from_canonical_u32(vm_record.from_pc); + *cols.instruction.dst_reg_ptr = F::from_canonical_u32(vm_record.dst_reg_ptr); + *cols.instruction.state_reg_ptr = F::from_canonical_u32(vm_record.state_reg_ptr); + *cols.instruction.input_reg_ptr = F::from_canonical_u32(vm_record.input_reg_ptr); + + let dst_ptr_limbs = vm_record.dst_ptr.to_le_bytes(); + let state_ptr_limbs = vm_record.state_ptr.to_le_bytes(); + let input_ptr_limbs = 
vm_record.input_ptr.to_le_bytes(); + set_arrayview_from_u8_slice(&mut cols.instruction.dst_ptr_limbs, dst_ptr_limbs); + set_arrayview_from_u8_slice(&mut cols.instruction.state_ptr_limbs, state_ptr_limbs); + set_arrayview_from_u8_slice(&mut cols.instruction.input_ptr_limbs, input_ptr_limbs); + let needs_range_check = [ + dst_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + state_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + input_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + input_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1], + ]; + let shift: u32 = 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits); + for pair in needs_range_check.chunks_exact(2) { + self.bitwise_lookup_chip + .request_range(pair[0] as u32 * shift, pair[1] as u32 * shift); + } + + // fill in the register reads aux + let mut timestamp = vm_record.timestamp; + for (cols, vm_record) in cols + .mem + .register_aux + .iter_mut() + .zip(vm_record.register_reads_aux.iter()) + { + mem_helper.fill(vm_record.prev_timestamp, timestamp, cols.as_mut()); + timestamp += 1; + } + + input_reads_aux.iter().zip(cols.mem.input_reads).for_each( + |(read_aux_record, read_aux_cols)| { + mem_helper.fill( + read_aux_record.prev_timestamp, + timestamp, + read_aux_cols.as_mut(), + ); + timestamp += 1; + }, + ); + + state_reads_aux.iter().zip(cols.mem.state_reads).for_each( + |(state_aux_record, state_aux_cols)| { + mem_helper.fill( + state_aux_record.prev_timestamp, + timestamp, + state_aux_cols.as_mut(), + ); + timestamp += 1; + }, + ); + + write_aux + .iter() + .zip(cols.mem.write_aux) + .for_each(|(write_aux_record, write_aux_cols)| { + mem_helper.fill( + write_aux_record.prev_timestamp, + timestamp, + write_aux_cols.as_mut(), + ); + timestamp += 1; + }); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/mod.rs b/extensions/sha2/circuit/src/sha2_chips/mod.rs new file mode 100644 index 0000000000..4378e4787e --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/mod.rs @@ -0,0 +1,35 @@ +mod block_hasher_chip; 
+mod config; +mod execution; +mod main_chip; +mod trace; + +use std::marker::PhantomData; + +pub use block_hasher_chip::*; +pub use config::*; +pub use execution::*; +pub use main_chip::*; +pub use trace::*; + +#[cfg(test)] +mod test_utils; +#[cfg(test)] +mod tests; +#[cfg(test)] +pub use test_utils::*; + +#[derive(derive_new::new, Clone)] +pub struct Sha2VmExecutor { + pub offset: usize, + pub pointer_max_bits: usize, + _phantom: PhantomData, +} + +// Indicates the message type of the interactions on the sha bus +#[repr(u8)] +pub enum MessageType { + State, + Message1, + Message2, +} diff --git a/extensions/sha2/circuit/src/sha2_chips/old/air.rs b/extensions/sha2/circuit/src/sha2_chips/old/air.rs new file mode 100644 index 0000000000..4e57c9c16d --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/old/air.rs @@ -0,0 +1,376 @@ +use std::{cmp::min, convert::TryInto}; + +use openvm_circuit::{ + arch::{ExecutionBridge, SystemPort}, + system::memory::{ + offline_checker::{MemoryBridge, MemoryWriteAuxCols}, + MemoryAddress, + }, +}; +use openvm_circuit_primitives::{ + bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir, +}; +use openvm_instructions::{ + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_sha2_air::{compose, Sha256Config, Sha2Air, Sha2Variant, Sha512Config}; +use openvm_stark_backend::{ + interaction::{BusIndex, InteractionBuilder}, + p3_air::{Air, AirBuilder, BaseAir}, + p3_field::{Field, FieldAlgebra}, + p3_matrix::Matrix, + rap::{BaseAirWithPublicValues, PartitionedBaseAir}, +}; + +use super::{Sha2BlockHasherDigestColsRef, Sha2BlockHasherRoundColsRef, Sha2ChipConfig}; + +#[derive(Clone, Debug)] +pub struct Sha2BlockHasherAir { + /// Bus to send byte checks to + pub bitwise_lookup_bus: BitwiseOperationLookupBus, + pub(super) sha_subair: Sha2Air, +} + +impl Sha2BlockHasherAir { + pub fn new( + SystemPort { + execution_bus, + program_bus, + memory_bridge, 
+ }: SystemPort, + bitwise_lookup_bus: BitwiseOperationLookupBus, + ptr_max_bits: usize, + self_bus_idx: BusIndex, + ) -> Self { + Self { + bitwise_lookup_bus, + sha_subair: Sha2Air::::new(bitwise_lookup_bus, self_bus_idx), + } + } +} + +impl BaseAirWithPublicValues for Sha2BlockHasherAir {} +impl PartitionedBaseAir for Sha2BlockHasherAir {} +impl BaseAir for Sha2BlockHasherAir { + fn width(&self) -> usize { + C::VM_WIDTH + } +} + +impl Air for Sha2BlockHasherAir { + fn eval(&self, builder: &mut AB) { + self.eval_transitions(builder); + self.eval_reads(builder); + self.eval_last_row(builder); + + self.sha_subair.eval(builder, C::BLOCK_HASHER_CONTROL_WIDTH); + } +} + +impl Sha2BlockHasherAir { + /// Implement constraints on `len`, `read_ptr` and `cur_timestamp` + fn eval_transitions(&self, builder: &mut AB) { + let main = builder.main(); + let (local, next) = (main.row_slice(0), main.row_slice(1)); + let local_cols = + Sha2BlockHasherRoundColsRef::::from::(&local[..C::VM_ROUND_WIDTH]); + let next_cols = + Sha2BlockHasherRoundColsRef::::from::(&next[..C::VM_ROUND_WIDTH]); + + let is_last_row = + *local_cols.inner.flags.is_last_block * *local_cols.inner.flags.is_digest_row; + // Len should be the same for the entire message + builder + .when_transition() + .when(not::(is_last_row.clone())) + .assert_eq(*next_cols.control.len, *local_cols.control.len); + + // Read ptr should increment by [C::READ_SIZE] for the first 4 rows and stay the same + // otherwise + let read_ptr_delta = + *local_cols.inner.flags.is_first_4_rows * AB::Expr::from_canonical_usize(C::READ_SIZE); + builder + .when_transition() + .when(not::(is_last_row.clone())) + .assert_eq( + *next_cols.control.read_ptr, + *local_cols.control.read_ptr + read_ptr_delta, + ); + + // Timestamp should increment by 1 for the first 4 rows and stay the same otherwise + let timestamp_delta = *local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE; + builder + .when_transition() + .when(not::(is_last_row.clone())) + 
.assert_eq( + *next_cols.control.cur_timestamp, + *local_cols.control.cur_timestamp + timestamp_delta, + ); + } + + /// Implement the reads for the first 4 rows of a block + fn eval_reads(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local_cols = + Sha2BlockHasherRoundColsRef::::from::(&local[..C::VM_ROUND_WIDTH]); + + let message: Vec = (0..C::READ_SIZE) + .map(|i| { + local_cols.inner.message_schedule.carry_or_buffer + [[i / (C::WORD_U16S * 2), i % (C::WORD_U16S * 2)]] + }) + .collect(); + + match C::VARIANT { + Sha2Variant::Sha256 => { + let message: [AB::Var; Sha256Config::READ_SIZE] = + message.try_into().unwrap_or_else(|_| { + panic!("message is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + *local_cols.control.read_ptr, + ), + message, + *local_cols.control.cur_timestamp, + local_cols.read_aux, + ) + .eval(builder, *local_cols.inner.flags.is_first_4_rows); + } + // Sha512 and Sha384 have the same read size so we put them together + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let message: [AB::Var; Sha512Config::READ_SIZE] = + message.try_into().unwrap_or_else(|_| { + panic!("message is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + *local_cols.control.read_ptr, + ), + message, + *local_cols.control.cur_timestamp, + local_cols.read_aux, + ) + .eval(builder, *local_cols.inner.flags.is_first_4_rows); + } + } + } + /// Implement the constraints for the last row of a message + fn eval_last_row(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0); + let local_cols = + Sha2BlockHasherDigestColsRef::::from::(&local[..C::VM_DIGEST_WIDTH]); + + let timestamp: AB::Var = local_cols.from_state.timestamp; + let mut timestamp_delta: usize = 0; + let mut timestamp_pp = || { + timestamp_delta += 1; + timestamp + 
AB::Expr::from_canonical_usize(timestamp_delta - 1) + }; + + let is_last_row = + *local_cols.inner.flags.is_last_block * *local_cols.inner.flags.is_digest_row; + + let dst_ptr: [AB::Var; RV32_REGISTER_NUM_LIMBS] = + local_cols.dst_ptr.to_vec().try_into().unwrap_or_else(|_| { + panic!("dst_ptr is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + *local_cols.rd_ptr, + ), + dst_ptr, + timestamp_pp(), + &local_cols.register_reads_aux[0], + ) + .eval(builder, is_last_row.clone()); + + let src_ptr: [AB::Var; RV32_REGISTER_NUM_LIMBS] = + local_cols.src_ptr.to_vec().try_into().unwrap_or_else(|_| { + panic!("src_ptr is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + *local_cols.rs1_ptr, + ), + src_ptr, + timestamp_pp(), + &local_cols.register_reads_aux[1], + ) + .eval(builder, is_last_row.clone()); + + let len_data: [AB::Var; RV32_REGISTER_NUM_LIMBS] = + local_cols.len_data.to_vec().try_into().unwrap_or_else(|_| { + panic!("len_data is not the correct size"); + }); + self.memory_bridge + .read( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + *local_cols.rs2_ptr, + ), + len_data, + timestamp_pp(), + &local_cols.register_reads_aux[2], + ) + .eval(builder, is_last_row.clone()); + // range check that the memory pointers don't overflow + // Note: no need to range check the length since we read from memory step by step and + // the memory bus will catch any memory accesses beyond ptr_max_bits + let shift = AB::Expr::from_canonical_usize( + 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits), + ); + // This only works if self.ptr_max_bits >= 24 which is typically the case + self.bitwise_lookup_bus + .send_range( + // It is fine to shift like this since we already know that dst_ptr and src_ptr + // have [RV32_CELL_BITS] bits + local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] 
* shift.clone(), + local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), + ) + .eval(builder, is_last_row.clone()); + + // the number of reads that happened to read the entire message: we do 4 reads per block + let time_delta = (*local_cols.inner.flags.local_block_idx + AB::Expr::ONE) + * AB::Expr::from_canonical_usize(4); + // Every time we read the message we increment the read pointer by C::READ_SIZE + let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(C::READ_SIZE); + + let result: Vec = (0..C::HASH_SIZE) + .map(|i| { + // The limbs are written in big endian order to the memory so need to be reversed + local_cols.inner.final_hash[[i / C::WORD_U8S, C::WORD_U8S - i % C::WORD_U8S - 1]] + }) + .collect(); + + let dst_ptr_val = compose::( + local_cols.dst_ptr.mapv(|x| x.into()).as_slice().unwrap(), + RV32_CELL_BITS, + ); + + match C::VARIANT { + Sha2Variant::Sha256 => { + debug_assert_eq!(C::NUM_WRITES, 1); + debug_assert_eq!(local_cols.writes_aux_base.len(), 1); + debug_assert_eq!(local_cols.writes_aux_prev_data.nrows(), 1); + let prev_data: [AB::Var; Sha256Config::HASH_SIZE] = local_cols + .writes_aux_prev_data + .row(0) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("writes_aux_prev_data is not the correct size"); + }); + // Note: revisit in the future to do 2 block writes of 16 cells instead of 1 block + // write of 32 cells. 
This could be beneficial as the output is often an input for + // another hash + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + dst_ptr_val, + ), + result.try_into().unwrap_or_else(|_| { + panic!("result is not the correct size"); + }), + timestamp_pp() + time_delta.clone(), + &MemoryWriteAuxCols::from_base(local_cols.writes_aux_base[0], prev_data), + ) + .eval(builder, is_last_row.clone()); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + debug_assert_eq!(C::NUM_WRITES, 2); + debug_assert_eq!(local_cols.writes_aux_base.len(), 2); + debug_assert_eq!(local_cols.writes_aux_prev_data.nrows(), 2); + + // For Sha384, set the last 16 cells to 0 + let mut truncated_result: Vec = + result.iter().map(|x| (*x).into()).collect(); + for x in truncated_result.iter_mut().skip(C::DIGEST_SIZE) { + *x = AB::Expr::ZERO; + } + + // write the digest in two halves because we only support writes up to 32 bytes + for i in 0..Sha512Config::NUM_WRITES { + let prev_data: [AB::Var; Sha512Config::WRITE_SIZE] = local_cols + .writes_aux_prev_data + .row(i) + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("writes_aux_prev_data is not the correct size"); + }); + + self.memory_bridge + .write( + MemoryAddress::new( + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + dst_ptr_val.clone() + + AB::Expr::from_canonical_usize(i * Sha512Config::WRITE_SIZE), + ), + truncated_result + [i * Sha512Config::WRITE_SIZE..(i + 1) * Sha512Config::WRITE_SIZE] + .to_vec() + .try_into() + .unwrap_or_else(|_| { + panic!("result is not the correct size"); + }), + timestamp_pp() + time_delta.clone(), + &MemoryWriteAuxCols::from_base( + local_cols.writes_aux_base[i], + prev_data, + ), + ) + .eval(builder, is_last_row.clone()); + } + } + } + self.execution_bridge + .execute_and_increment_pc( + AB::Expr::from_canonical_usize(C::OPCODE.global_opcode().as_usize()), + [ + >::into(*local_cols.rd_ptr), + >::into(*local_cols.rs1_ptr), + 
>::into(*local_cols.rs2_ptr), + AB::Expr::from_canonical_u32(RV32_REGISTER_AS), + AB::Expr::from_canonical_u32(RV32_MEMORY_AS), + ], + *local_cols.from_state, + AB::Expr::from_canonical_usize(timestamp_delta) + time_delta.clone(), + ) + .eval(builder, is_last_row.clone()); + + // Assert that we read the correct length of the message + let len_val = compose::( + local_cols.len_data.mapv(|x| x.into()).as_slice().unwrap(), + RV32_CELL_BITS, + ); + builder + .when(is_last_row.clone()) + .assert_eq(*local_cols.control.len, len_val); + // Assert that we started reading from the correct pointer initially + let src_val = compose::( + local_cols.src_ptr.mapv(|x| x.into()).as_slice().unwrap(), + RV32_CELL_BITS, + ); + builder + .when(is_last_row.clone()) + .assert_eq(*local_cols.control.read_ptr, src_val + read_ptr_delta); + // Assert that we started reading from the correct timestamp + builder.when(is_last_row.clone()).assert_eq( + *local_cols.control.cur_timestamp, + local_cols.from_state.timestamp + AB::Expr::from_canonical_u32(3) + time_delta, + ); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/old/columns.rs b/extensions/sha2/circuit/src/sha2_chips/old/columns.rs new file mode 100644 index 0000000000..ad6a01bfe2 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/old/columns.rs @@ -0,0 +1,105 @@ +//! 
WARNING: the order of fields in the structs is important, do not change it + +use openvm_circuit::{ + arch::ExecutionState, + system::memory::offline_checker::{MemoryBaseAuxCols, MemoryReadAuxCols}, +}; +use openvm_circuit_primitives_derive::ColsRef; +use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; +use openvm_sha2_air::{ + Sha2DigestCols, Sha2DigestColsRef, Sha2DigestColsRefMut, Sha2RoundCols, Sha2RoundColsRef, + Sha2RoundColsRefMut, +}; + +use super::{Sha2ChipConfig, SHA_REGISTER_READS}; + +/// the first C::ROUND_ROWS rows of every SHA block will be of type ShaVmRoundCols and the last row +/// will be of type ShaVmDigestCols +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2ChipConfig)] +pub struct Sha2BlockHasherRoundCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, +> { + pub control: Sha2BlockHasherControlCols, + pub inner: Sha2RoundCols< + T, + WORD_BITS, + WORD_U8S, + WORD_U16S, + ROUNDS_PER_ROW, + ROUNDS_PER_ROW_MINUS_ONE, + ROW_VAR_CNT, + >, + #[aligned_borrow] + pub read_aux: MemoryReadAuxCols, +} + +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2ChipConfig)] +pub struct Sha2BlockHasherDigestCols< + T, + const WORD_BITS: usize, + const WORD_U8S: usize, + const WORD_U16S: usize, + const HASH_WORDS: usize, + const ROUNDS_PER_ROW: usize, + const ROUNDS_PER_ROW_MINUS_ONE: usize, + const ROW_VAR_CNT: usize, + const NUM_WRITES: usize, + const WRITE_SIZE: usize, +> { + pub control: Sha2BlockHasherControlCols, + pub inner: Sha2DigestCols< + T, + WORD_BITS, + WORD_U8S, + WORD_U16S, + HASH_WORDS, + ROUNDS_PER_ROW, + ROUNDS_PER_ROW_MINUS_ONE, + ROW_VAR_CNT, + >, + #[aligned_borrow] + pub from_state: ExecutionState, + /// It is counter intuitive, but we will constrain the register reads on the very last row of + /// every message + pub rd_ptr: T, + pub rs1_ptr: T, + pub rs2_ptr: T, + pub 
dst_ptr: [T; RV32_REGISTER_NUM_LIMBS], + pub src_ptr: [T; RV32_REGISTER_NUM_LIMBS], + pub len_data: [T; RV32_REGISTER_NUM_LIMBS], + #[aligned_borrow] + pub register_reads_aux: [MemoryReadAuxCols; SHA_REGISTER_READS], + // We store the fields of MemoryWriteAuxCols here because the length of prev_data depends on + // the sha variant + #[aligned_borrow] + pub writes_aux_base: [MemoryBaseAuxCols; NUM_WRITES], + pub writes_aux_prev_data: [[T; WRITE_SIZE]; NUM_WRITES], +} + +/// These are the columns that are used on both round and digest rows +#[repr(C)] +#[derive(Clone, Copy, Debug, ColsRef)] +#[config(Sha2ChipConfig)] +pub struct Sha2BlockHasherControlCols { + /// Note: We will use the buffer in `inner.message_schedule` as the message data + /// This is the length of the entire message in bytes + pub len: T, + /// Need to keep timestamp and read_ptr since block reads don't have the necessary information + pub cur_timestamp: T, + pub read_ptr: T, + /// Padding flags which will be used to encode the the number of non-padding cells in the + /// current row + pub pad_flags: [T; 9], + /// A boolean flag that indicates whether a padding already occurred + pub padding_occurred: T, +} diff --git a/extensions/sha2/circuit/src/sha2_chips/old/config-old.rs b/extensions/sha2/circuit/src/sha2_chips/old/config-old.rs new file mode 100644 index 0000000000..8bd79288a8 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/old/config-old.rs @@ -0,0 +1,101 @@ +use openvm_instructions::riscv::RV32_CELL_BITS; +use openvm_sha2_air::{Sha256Config, Sha2Config, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; + +use super::{ + Sha2BlockHasherControlColsRef, Sha2BlockHasherDigestColsRef, Sha2BlockHasherRoundColsRef, +}; + +pub trait Sha2ChipConfig: Sha2BlockHasherConfig { + // Name of the opcode + const OPCODE_NAME: &'static str; + /// Width of the ShaVmControlCols + const BLOCK_HASHER_CONTROL_WIDTH: usize = Sha2BlockHasherControlColsRef::::width::(); + /// Width 
of the ShaVmRoundCols + const VM_ROUND_WIDTH: usize = Sha2BlockHasherRoundColsRef::::width::(); + /// Width of the ShaVmDigestCols + const VM_DIGEST_WIDTH: usize = Sha2BlockHasherDigestColsRef::::width::(); + /// Width of the ShaVmCols + const VM_WIDTH: usize = if Self::VM_ROUND_WIDTH > Self::VM_DIGEST_WIDTH { + Self::VM_ROUND_WIDTH + } else { + Self::VM_DIGEST_WIDTH + }; + /// Number of bits to use when padding the message length. Given by the SHA-2 spec. + const MESSAGE_LENGTH_BITS: usize; + /// Maximum i such that `FirstPadding_i` is a valid padding flag + const MAX_FIRST_PADDING: usize = Self::CELLS_PER_ROW - 1; + /// Maximum i such that `FirstPadding_i_LastRow` is a valid padding flag + const MAX_FIRST_PADDING_LAST_ROW: usize = + Self::CELLS_PER_ROW - Self::MESSAGE_LENGTH_BITS / 8 - 1; + /// OpenVM Opcode for the instruction + const OPCODE: Rv32Sha2Opcode; + + // ==== Constants for register/memory adapter ==== + /// Number of rv32 cells read in a block + const BLOCK_CELLS: usize = Self::BLOCK_BITS / RV32_CELL_BITS; + /// Number of rows we will do a read on for each block + const NUM_READ_ROWS: usize = Self::MESSAGE_ROWS; + + /// Number of cells to read in a single memory access + const READ_SIZE: usize = Self::WORD_U8S * Self::ROUNDS_PER_ROW; + /// Number of cells in the digest before truncation (Sha384 truncates the digest) + const HASH_SIZE: usize = Self::WORD_U8S * Self::HASH_WORDS; + /// Number of cells in the digest after truncation + const DIGEST_SIZE: usize; + + /// Number of parts to write the hash in + const NUM_WRITES: usize = Self::HASH_SIZE / Self::WRITE_SIZE; + /// Size of each write. 
Must divide Self::HASH_SIZE + const WRITE_SIZE: usize; +} + +/// Register reads to get dst, src, len +pub const SHA_REGISTER_READS: usize = 3; + +impl Sha2ChipConfig for Sha256Config { + const OPCODE_NAME: &'static str = "SHA256"; + const MESSAGE_LENGTH_BITS: usize = 64; + const WRITE_SIZE: usize = SHA_WRITE_SIZE; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA256; + // no truncation + const DIGEST_SIZE: usize = Self::HASH_SIZE; +} + +impl Sha2ChipConfig for Sha512Config { + const OPCODE_NAME: &'static str = "SHA512"; + const MESSAGE_LENGTH_BITS: usize = 128; + const WRITE_SIZE: usize = SHA_WRITE_SIZE; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA512; + // no truncation + const DIGEST_SIZE: usize = Self::HASH_SIZE; +} + +impl Sha2ChipConfig for Sha384Config { + const OPCODE_NAME: &'static str = "SHA384"; + const MESSAGE_LENGTH_BITS: usize = 128; + const WRITE_SIZE: usize = SHA_WRITE_SIZE; + const OPCODE: Rv32Sha2Opcode = Rv32Sha2Opcode::SHA384; + // Sha384 truncates the output to 48 cells + const DIGEST_SIZE: usize = 48; +} + +// We use the same write size for all variants to simplify tracegen record storage. +// In particular, each memory write aux record will have the same size, which is useful for +// defining Sha2VmRecordHeader in a repr(C) way. 
+pub const SHA_WRITE_SIZE: usize = 32; + +pub const MAX_SHA_NUM_WRITES: usize = if Sha256Config::NUM_WRITES > Sha512Config::NUM_WRITES { + if Sha256Config::NUM_WRITES > Sha384Config::NUM_WRITES { + Sha256Config::NUM_WRITES + } else { + Sha384Config::NUM_WRITES + } +} else if Sha512Config::NUM_WRITES > Sha384Config::NUM_WRITES { + Sha512Config::NUM_WRITES +} else { + Sha384Config::NUM_WRITES +}; + +/// Maximum message length that this chip supports in bytes +pub const SHA_MAX_MESSAGE_LEN: usize = 1 << 29; diff --git a/extensions/sha256/circuit/src/sha256_chip/cuda.rs b/extensions/sha2/circuit/src/sha2_chips/old/cuda.rs similarity index 100% rename from extensions/sha256/circuit/src/sha256_chip/cuda.rs rename to extensions/sha2/circuit/src/sha2_chips/old/cuda.rs diff --git a/extensions/sha256/circuit/src/sha256_chip/execution.rs b/extensions/sha2/circuit/src/sha2_chips/old/execution.rs similarity index 100% rename from extensions/sha256/circuit/src/sha256_chip/execution.rs rename to extensions/sha2/circuit/src/sha2_chips/old/execution.rs diff --git a/extensions/sha256/circuit/src/sha256_chip/mod.rs b/extensions/sha2/circuit/src/sha2_chips/old/mod.rs similarity index 100% rename from extensions/sha256/circuit/src/sha256_chip/mod.rs rename to extensions/sha2/circuit/src/sha2_chips/old/mod.rs diff --git a/extensions/sha2/circuit/src/sha2_chips/old/tests.rs b/extensions/sha2/circuit/src/sha2_chips/old/tests.rs new file mode 100644 index 0000000000..0fca041e74 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/old/tests.rs @@ -0,0 +1,323 @@ +use std::array; + +use openvm_circuit::{ + arch::{ + testing::{memory::gen_pointer, VmChipTestBuilder, BITWISE_OP_LOOKUP_BUS}, + DenseRecordArena, InsExecutorE1, InstructionExecutor, NewVmChipWrapper, + }, + utils::get_random_message, +}; +use openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip, +}; +use openvm_instructions::{instruction::Instruction, 
riscv::RV32_CELL_BITS, LocalOpcode}; +use openvm_sha2_air::{Sha256Config, Sha2Variant, Sha384Config, Sha512Config}; +use openvm_sha2_transpiler::Rv32Sha2Opcode; +use openvm_stark_backend::{interaction::BusIndex, p3_field::FieldAlgebra}; +use openvm_stark_sdk::{config::setup_tracing, p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; + +use super::{Sha2BlockHasheAir, Sha2ChipConfig, Sha2VmChip, Sha2VmStep}; +use crate::{ + sha2_chip::trace::Sha2VmRecordLayout, sha2_solve, Sha2BlockHasherDigestColsRef, + Sha2BlockHasherRoundColsRef, +}; + +type F = BabyBear; +const SELF_BUS_IDX: BusIndex = 28; +const MAX_INS_CAPACITY: usize = 8192; +type Sha2VmChipDense = + NewVmChipWrapper, Sha2VmStep, DenseRecordArena>; + +fn create_test_chips( + tester: &mut VmChipTestBuilder, +) -> ( + Sha2VmChip, + SharedBitwiseOperationLookupChip, +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + let mut chip = Sha2VmChip::::new( + Sha2BlockHasheAir::new( + tester.system_port(), + bitwise_bus, + tester.address_bits(), + SELF_BUS_IDX, + ), + Sha2VmStep::new( + bitwise_chip.clone(), + Rv32Sha2Opcode::CLASS_OFFSET, + tester.address_bits(), + ), + tester.memory_helper(), + ); + chip.set_trace_height(MAX_INS_CAPACITY); + + (chip, bitwise_chip) +} + +fn set_and_execute, C: Sha2ChipConfig>( + tester: &mut VmChipTestBuilder, + chip: &mut E, + rng: &mut StdRng, + opcode: Rv32Sha2Opcode, + message: Option<&[u8]>, + len: Option, +) { + let len = len.unwrap_or(rng.gen_range(1..3000)); + let tmp = get_random_message(rng, len); + let message: &[u8] = message.unwrap_or(&tmp); + let len = message.len(); + + let rd = gen_pointer(rng, 4); + let rs1 = gen_pointer(rng, 4); + let rs2 = gen_pointer(rng, 4); + + let max_mem_ptr: u32 = 1 << tester.address_bits(); + let dst_ptr = rng.gen_range(0..max_mem_ptr - C::DIGEST_SIZE as u32); + let dst_ptr = dst_ptr ^ (dst_ptr & 3); + 
tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); + let src_ptr = rng.gen_range(0..(max_mem_ptr - len as u32)); + let src_ptr = src_ptr ^ (src_ptr & 3); + tester.write(1, rs1, src_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs2, len.to_le_bytes().map(F::from_canonical_u8)); + + message.chunks(4).enumerate().for_each(|(i, chunk)| { + let chunk: [&u8; 4] = array::from_fn(|i| chunk.get(i).unwrap_or(&0)); + tester.write( + 2, + src_ptr as usize + i * 4, + chunk.map(|&x| F::from_canonical_u8(x)), + ); + }); + + tester.execute( + chip, + &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), + ); + + let output = sha2_solve::(message); + match C::VARIANT { + Sha2Variant::Sha256 => { + assert_eq!( + output + .into_iter() + .map(F::from_canonical_u8) + .collect::>(), + tester.read::<{ Sha256Config::DIGEST_SIZE }>(2, dst_ptr as usize) + ); + } + Sha2Variant::Sha512 | Sha2Variant::Sha384 => { + let mut output = output; + output.extend(std::iter::repeat(0u8).take(C::HASH_SIZE)); + let output = output + .into_iter() + .map(F::from_canonical_u8) + .collect::>(); + for i in 0..C::NUM_WRITES { + assert_eq!( + output[i * C::WRITE_SIZE..(i + 1) * C::WRITE_SIZE], + tester.read::<{ Sha512Config::WRITE_SIZE }>( + 2, + dst_ptr as usize + i * C::WRITE_SIZE + ) + ); + } + } + } +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. 
+/////////////////////////////////////////////////////////////////////////////////////// +fn rand_sha_test() { + setup_tracing(); + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, bitwise_chip) = create_test_chips::(&mut tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute::<_, C>(&mut tester, &mut chip, &mut rng, C::OPCODE, None, None); + } + + let tester = tester.build().load(chip).load(bitwise_chip).finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn rand_sha256_test() { + rand_sha_test::(); +} + +#[test] +fn rand_sha512_test() { + rand_sha_test::(); +} + +#[test] +fn rand_sha384_test() { + rand_sha_test::(); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// SANITY TESTS +/// +/// Ensure that solve functions produce the correct results. +/////////////////////////////////////////////////////////////////////////////////////// +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut chip, _) = create_test_chips::(&mut tester); + + println!( + "Sha2VmDigestColsRef::::width::(): {}", + Sha2BlockHasherDigestColsRef::::width::() + ); + println!( + "Sha2VmRoundColsRef::::width::(): {}", + Sha2BlockHasherRoundColsRef::::width::() + ); + + let num_tests: usize = 1; + for _ in 0..num_tests { + set_and_execute::<_, C>(&mut tester, &mut chip, &mut rng, C::OPCODE, None, None); + } +} + +#[test] +fn sha256_roundtrip_sanity_test() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn sha512_roundtrip_sanity_test() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn sha384_roundtrip_sanity_test() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn sha256_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! 
Axiom is the best!"; + let output = sha2_solve::(input); + let expected: [u8; 32] = [ + 99, 196, 61, 185, 226, 212, 131, 80, 154, 248, 97, 108, 157, 55, 200, 226, 160, 73, 207, + 46, 245, 169, 94, 255, 42, 136, 193, 15, 40, 133, 173, 22, + ]; + assert_eq!(output, expected); +} + +#[test] +fn sha512_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = sha2_solve::(input); + // verified manually against the sha512 command line tool + let expected: [u8; 64] = [ + 0, 8, 195, 142, 70, 71, 97, 208, 132, 132, 243, 53, 179, 186, 8, 162, 71, 75, 126, 21, 130, + 203, 245, 126, 207, 65, 119, 60, 64, 79, 200, 2, 194, 17, 189, 137, 164, 213, 107, 197, + 152, 11, 242, 165, 146, 80, 96, 105, 249, 27, 139, 14, 244, 21, 118, 31, 94, 87, 32, 145, + 149, 98, 235, 75, + ]; + assert_eq!(output, expected); +} + +#[test] +fn sha384_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = sha2_solve::(input); + let expected: [u8; 48] = [ + 134, 227, 167, 229, 35, 110, 115, 174, 10, 27, 197, 116, 56, 144, 150, 36, 152, 190, 212, + 120, 26, 243, 125, 4, 2, 60, 164, 195, 218, 219, 255, 143, 240, 75, 158, 126, 102, 105, 8, + 202, 142, 240, 230, 161, 162, 152, 111, 71, + ]; + assert_eq!(output, expected); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// DENSE TESTS +/// +/// Ensure that the chip works as expected with dense records. +/// We first execute some instructions with a [DenseRecordArena] and transfer the records +/// to a [MatrixRecordArena]. After transferring we generate the trace and make sure that +/// all the constraints pass. 
+/////////////////////////////////////////////////////////////////////////////////////// +fn create_test_chip_dense( + tester: &mut VmChipTestBuilder, +) -> Sha2VmChipDense { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = SharedBitwiseOperationLookupChip::::new(bitwise_bus); + + let mut chip = Sha2VmChipDense::::new( + Sha2BlockHasheAir::::new( + tester.system_port(), + bitwise_chip.bus(), + tester.address_bits(), + SELF_BUS_IDX, + ), + Sha2VmStep::::new( + bitwise_chip.clone(), + Rv32Sha2Opcode::CLASS_OFFSET, + tester.address_bits(), + ), + tester.memory_helper(), + ); + + chip.set_trace_buffer_height(MAX_INS_CAPACITY); + chip +} + +fn dense_record_arena_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut sparse_chip, bitwise_chip) = create_test_chips::(&mut tester); + + { + let mut dense_chip = create_test_chip_dense::(&mut tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute::<_, C>( + &mut tester, + &mut dense_chip, + &mut rng, + C::OPCODE, + None, + None, + ); + } + + let mut record_interpreter = dense_chip + .arena + .get_record_seeker::<_, Sha2VmRecordLayout>(); + record_interpreter.transfer_to_matrix_arena(&mut sparse_chip.arena); + } + + let tester = tester + .build() + .load(sparse_chip) + .load(bitwise_chip) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn sha256_dense_record_arena_test() { + dense_record_arena_test::(); +} + +#[test] +fn sha512_dense_record_arena_test() { + dense_record_arena_test::(); +} + +#[test] +fn sha384_dense_record_arena_test() { + dense_record_arena_test::(); +} diff --git a/extensions/sha2/circuit/src/sha2_chips/old/trace.rs b/extensions/sha2/circuit/src/sha2_chips/old/trace.rs new file mode 100644 index 0000000000..5f6beeb80d --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/old/trace.rs @@ -0,0 +1,723 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + 
cmp::min, + iter, + marker::PhantomData, +}; + +use openvm_circuit::{ + arch::{ + get_record_from_slice, CustomBorrow, MultiRowLayout, MultiRowMetadata, RecordArena, Result, + SizedRecord, TraceFiller, TraceStep, VmStateMut, + }, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + MemoryAuxColsFactory, + }, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; +use openvm_sha2_air::{ + be_limbs_into_word, get_flag_pt_array, Sha256Config, Sha2StepHelper, Sha384Config, Sha512Config, +}; +use openvm_stark_backend::{ + p3_field::PrimeField32, + p3_matrix::{dense::RowMajorMatrix, Matrix}, + p3_maybe_rayon::prelude::*, +}; + +use super::{ + Sha2BlockHasherDigestColsRefMut, Sha2BlockHasherRoundColsRefMut, Sha2ChipConfig, Sha2Variant, + Sha2VmStep, +}; +use crate::{ + get_sha2_num_blocks, sha2_chip::PaddingFlags, sha2_solve, Sha2BlockHasherControlColsRefMut, + MAX_SHA_NUM_WRITES, SHA_MAX_MESSAGE_LEN, SHA_REGISTER_READS, SHA_WRITE_SIZE, +}; + +#[derive(Clone, Copy)] +pub struct Sha2VmMetadata { + pub num_blocks: u32, + _phantom: PhantomData, +} + +impl MultiRowMetadata for Sha2VmMetadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + self.num_blocks as usize * C::ROWS_PER_BLOCK + } +} + +pub(crate) type Sha2VmRecordLayout = MultiRowLayout>; + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct Sha2VmRecordHeader { + pub from_pc: u32, + pub timestamp: u32, + pub rd_ptr: u32, + pub rs1_ptr: u32, + pub rs2_ptr: u32, + pub dst_ptr: u32, + pub src_ptr: u32, + pub len: u32, + + pub register_reads_aux: [MemoryReadAuxRecord; SHA_REGISTER_READS], + // Note: MAX_SHA_NUM_WRITES = 2 because SHA-256 uses 1 write, while 
SHA-512 and SHA-384 use 2 + // writes. We just use the same array for all variants to simplify record storage. + pub writes_aux: [MemoryWriteBytesAuxRecord; MAX_SHA_NUM_WRITES], +} + +pub struct Sha2VmRecordMut<'a> { + pub inner: &'a mut Sha2VmRecordHeader, + // Having a continuous slice of the input is useful for fast hashing in `execute` + pub input: &'a mut [u8], + pub read_aux: &'a mut [MemoryReadAuxRecord], +} + +/// Custom borrowing that splits the buffer into a fixed `Sha2VmRecord` header +/// followed by a slice of `u8`'s of length `C::BLOCK_CELLS * num_blocks` where `num_blocks` is +/// provided at runtime, followed by a slice of `MemoryReadAuxRecord`'s of length +/// `C::NUM_READ_ROWS * num_blocks`. Uses `align_to_mut()` to make sure the slice is properly +/// aligned to `MemoryReadAuxRecord`. Has debug assertions that check the size and alignment of the +/// slices. +impl<'a, C: Sha2ChipConfig> CustomBorrow<'a, Sha2VmRecordMut<'a>, Sha2VmRecordLayout> + for [u8] +{ + fn custom_borrow(&'a mut self, layout: Sha2VmRecordLayout) -> Sha2VmRecordMut<'a> { + let (header_buf, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + let header: &mut Sha2VmRecordHeader = header_buf.borrow_mut(); + + // Using `split_at_mut_unchecked` for perf reasons + // input is a slice of `u8`'s of length `C::BLOCK_CELLS * num_blocks`, so the alignment + // is always satisfied + let (input, rest) = unsafe { + rest.split_at_mut_unchecked((layout.metadata.num_blocks as usize) * C::BLOCK_CELLS) + }; + + // Using `align_to_mut` to make sure the returned slice is properly aligned to + // `MemoryReadAuxRecord` Additionally, Rust's subslice operation (a few lines below) + // will verify that the buffer has enough capacity + let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::() }; + Sha2VmRecordMut { + inner: header, + input, + read_aux: &mut read_aux_buf[..(layout.metadata.num_blocks as usize) * C::NUM_READ_ROWS], + } + } + + unsafe fn extract_layout(&self) -> 
Sha2VmRecordLayout { + let header: &Sha2VmRecordHeader = self.borrow(); + + Sha2VmRecordLayout { + metadata: Sha2VmMetadata { + num_blocks: get_sha2_num_blocks::(header.len), + _phantom: PhantomData::, + }, + } + } +} + +impl SizedRecord> for Sha2VmRecordMut<'_> { + fn size(layout: &Sha2VmRecordLayout) -> usize { + let mut total_len = size_of::(); + total_len += layout.metadata.num_blocks as usize * C::BLOCK_CELLS; + // Align the pointer to the alignment of `MemoryReadAuxRecord` + total_len = total_len.next_multiple_of(align_of::()); + total_len += layout.metadata.num_blocks as usize + * C::NUM_READ_ROWS + * size_of::(); + total_len + } + + fn alignment(_layout: &Sha2VmRecordLayout) -> usize { + align_of::() + } +} + +impl TraceStep for Sha2VmStep { + type RecordLayout = Sha2VmRecordLayout; + type RecordMut<'a> = Sha2VmRecordMut<'a>; + + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", C::OPCODE) + } + + fn execute<'buf, RA>( + &mut self, + state: VmStateMut, CTX>, + instruction: &Instruction, + arena: &'buf mut RA, + ) -> Result<()> + where + RA: RecordArena<'buf, Self::RecordLayout, Self::RecordMut<'buf>>, + { + let Instruction { + opcode, + a, + b, + c, + d, + e, + .. 
+ } = instruction; + debug_assert_eq!(*opcode, C::OPCODE.global_opcode()); + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + + // Reading the length first to allocate a record of correct size + let len = read_rv32_register(state.memory.data(), c.as_canonical_u32()); + + let num_blocks = get_sha2_num_blocks::(len); + let record = arena.alloc(MultiRowLayout { + metadata: Sha2VmMetadata { + num_blocks, + _phantom: PhantomData::, + }, + }); + + record.inner.from_pc = *state.pc; + record.inner.timestamp = state.memory.timestamp(); + record.inner.rd_ptr = a.as_canonical_u32(); + record.inner.rs1_ptr = b.as_canonical_u32(); + record.inner.rs2_ptr = c.as_canonical_u32(); + + record.inner.dst_ptr = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rd_ptr, + &mut record.inner.register_reads_aux[0].prev_timestamp, + )); + record.inner.src_ptr = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs1_ptr, + &mut record.inner.register_reads_aux[1].prev_timestamp, + )); + record.inner.len = u32::from_le_bytes(tracing_read( + state.memory, + RV32_REGISTER_AS, + record.inner.rs2_ptr, + &mut record.inner.register_reads_aux[2].prev_timestamp, + )); + + // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used + debug_assert!( + record.inner.src_ptr as usize + num_blocks as usize * C::BLOCK_CELLS + <= (1 << self.pointer_max_bits) + ); + debug_assert!( + record.inner.dst_ptr as usize + C::WRITE_SIZE <= (1 << self.pointer_max_bits) + ); + // We don't support messages longer than 2^29 bytes + debug_assert!(record.inner.len < SHA_MAX_MESSAGE_LEN as u32); + + for block_idx in 0..num_blocks as usize { + // Reads happen on the first 4 rows of each block + for row in 0..C::NUM_READ_ROWS { + let read_idx = block_idx * C::NUM_READ_ROWS + row; + match C::VARIANT { + Sha2Variant::Sha256 => { + let row_input: [u8; 
Sha256Config::READ_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + (read_idx * C::READ_SIZE) as u32, + &mut record.read_aux[read_idx].prev_timestamp, + ); + record.input[read_idx * C::READ_SIZE..(read_idx + 1) * C::READ_SIZE] + .copy_from_slice(&row_input); + } + Sha2Variant::Sha512 => { + let row_input: [u8; Sha512Config::READ_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + (read_idx * C::READ_SIZE) as u32, + &mut record.read_aux[read_idx].prev_timestamp, + ); + record.input[read_idx * C::READ_SIZE..(read_idx + 1) * C::READ_SIZE] + .copy_from_slice(&row_input); + } + Sha2Variant::Sha384 => { + let row_input: [u8; Sha384Config::READ_SIZE] = tracing_read( + state.memory, + RV32_MEMORY_AS, + record.inner.src_ptr + (read_idx * C::READ_SIZE) as u32, + &mut record.read_aux[read_idx].prev_timestamp, + ); + record.input[read_idx * C::READ_SIZE..(read_idx + 1) * C::READ_SIZE] + .copy_from_slice(&row_input); + } + } + } + } + + let mut output = sha2_solve::(&record.input[..len as usize]); + match C::VARIANT { + Sha2Variant::Sha256 => { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr, + output.try_into().unwrap(), + &mut record.inner.writes_aux[0].prev_timestamp, + &mut record.inner.writes_aux[0].prev_data, + ); + } + Sha2Variant::Sha512 => { + debug_assert!(output.len() % Sha512Config::WRITE_SIZE == 0); + output + .chunks_exact(Sha512Config::WRITE_SIZE) + .enumerate() + .for_each(|(i, chunk)| { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr + (i * Sha512Config::WRITE_SIZE) as u32, + chunk.try_into().unwrap(), + &mut record.inner.writes_aux[i].prev_timestamp, + &mut record.inner.writes_aux[i].prev_data, + ); + }); + } + Sha2Variant::Sha384 => { + // output is a truncated 48-byte digest, so we will append 16 bytes of zeros + output.extend(iter::repeat(0).take(16)); + debug_assert!(output.len() % Sha384Config::WRITE_SIZE == 0); + output + 
.chunks_exact(Sha384Config::WRITE_SIZE) + .enumerate() + .for_each(|(i, chunk)| { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr + (i * Sha384Config::WRITE_SIZE) as u32, + chunk.try_into().unwrap(), + &mut record.inner.writes_aux[i].prev_timestamp, + &mut record.inner.writes_aux[i].prev_data, + ); + }); + } + } + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + + Ok(()) + } +} + +impl TraceFiller for Sha2VmStep { + fn fill_trace( + &self, + mem_helper: &MemoryAuxColsFactory, + trace_matrix: &mut RowMajorMatrix, + rows_used: usize, + ) { + if rows_used == 0 { + return; + } + + let mut chunks = Vec::with_capacity(trace_matrix.height() / C::ROWS_PER_BLOCK); + let mut sizes = Vec::with_capacity(trace_matrix.height() / C::ROWS_PER_BLOCK); + let mut trace = &mut trace_matrix.values[..]; + let mut num_blocks_so_far = 0; + + // First pass over the trace to get the number of blocks for each instruction + // and divide the matrix into chunks of needed sizes + loop { + if num_blocks_so_far * C::ROWS_PER_BLOCK >= rows_used { + // Push all the padding rows as a single chunk and break + chunks.push(trace); + sizes.push((0, num_blocks_so_far)); + break; + } else { + let record: &Sha2VmRecordHeader = unsafe { get_record_from_slice(&mut trace, ()) }; + let num_blocks = get_sha2_num_blocks::(record.len) as usize; + let (chunk, rest) = + trace.split_at_mut(C::VM_WIDTH * C::ROWS_PER_BLOCK * num_blocks); + chunks.push(chunk); + sizes.push((num_blocks, num_blocks_so_far)); + num_blocks_so_far += num_blocks; + trace = rest; + } + } + + // During the first pass we will fill out most of the matrix + // But there are some cells that can't be generated by the first pass so we will do a second + // pass over the matrix later + chunks.par_iter_mut().zip(sizes.par_iter()).for_each( + |(slice, (num_blocks, global_block_offset))| { + if global_block_offset * C::ROWS_PER_BLOCK >= rows_used { + // Fill in the invalid rows + 
slice.par_chunks_mut(C::VM_WIDTH).for_each(|row| { + // Need to get rid of the accidental garbage data that might overflow the + // F's prime field. Unfortunately, there is no good way around this + unsafe { + std::ptr::write_bytes( + row.as_mut_ptr() as *mut u8, + 0, + C::VM_WIDTH * size_of::(), + ); + } + let cols = Sha2BlockHasherRoundColsRefMut::::from::( + row[..C::VM_ROUND_WIDTH].borrow_mut(), + ); + self.inner.generate_default_row(cols.inner); + }); + return; + } + + let record: Sha2VmRecordMut = unsafe { + get_record_from_slice( + slice, + Sha2VmRecordLayout { + metadata: Sha2VmMetadata { + num_blocks: *num_blocks as u32, + _phantom: PhantomData::, + }, + }, + ) + }; + + let mut input: Vec = Vec::with_capacity(C::BLOCK_CELLS * num_blocks); + input.extend_from_slice(record.input); + let mut padded_input = input.clone(); + let len = record.inner.len as usize; + let padded_input_len = padded_input.len(); + padded_input[len] = 1 << (RV32_CELL_BITS - 1); + padded_input[len + 1..padded_input_len - 4].fill(0); + padded_input[padded_input_len - 4..] + .copy_from_slice(&((len as u32) << 3).to_be_bytes()); + + let mut prev_hashes = Vec::with_capacity(*num_blocks); + prev_hashes.push(C::get_h().to_vec()); + for i in 0..*num_blocks - 1 { + prev_hashes.push(Sha2StepHelper::::get_block_hash( + &prev_hashes[i], + padded_input[i * C::BLOCK_CELLS..(i + 1) * C::BLOCK_CELLS].into(), + )); + } + // Copy the read aux records and input to another place to safely fill in the trace + // matrix without overwriting the record + let mut read_aux_records = Vec::with_capacity(C::NUM_READ_ROWS * num_blocks); + read_aux_records.extend_from_slice(record.read_aux); + let vm_record = record.inner.clone(); + + slice + .par_chunks_exact_mut(C::VM_WIDTH * C::ROWS_PER_BLOCK) + .enumerate() + .for_each(|(block_idx, block_slice)| { + // Need to get rid of the accidental garbage data that might overflow the + // F's prime field. 
Unfortunately, there is no good way around this + unsafe { + std::ptr::write_bytes( + block_slice.as_mut_ptr() as *mut u8, + 0, + C::ROWS_PER_BLOCK * C::VM_WIDTH * size_of::(), + ); + } + self.fill_block_trace::( + block_slice, + &vm_record, + &read_aux_records + [block_idx * C::NUM_READ_ROWS..(block_idx + 1) * C::NUM_READ_ROWS], + &input[block_idx * C::BLOCK_CELLS..(block_idx + 1) * C::BLOCK_CELLS], + &padded_input + [block_idx * C::BLOCK_CELLS..(block_idx + 1) * C::BLOCK_CELLS], + block_idx == *num_blocks - 1, + *global_block_offset + block_idx, + block_idx, + prev_hashes[block_idx].as_slice(), + mem_helper, + ); + }); + }, + ); + + // Do a second pass over the trace to fill in the missing values + // Note, we need to skip the very first row + trace_matrix.values[C::VM_WIDTH..] + .par_chunks_mut(C::VM_WIDTH * C::ROWS_PER_BLOCK) + .take(rows_used / C::ROWS_PER_BLOCK) + .for_each(|chunk| { + self.inner.generate_missing_cells( + chunk, + C::VM_WIDTH, + C::BLOCK_HASHER_CONTROL_WIDTH, + ); + }); + } +} + +impl Sha2VmStep { + #[allow(clippy::too_many_arguments)] + fn fill_block_trace( + &self, + block_slice: &mut [F], + record: &Sha2VmRecordHeader, + read_aux_records: &[MemoryReadAuxRecord], + input: &[u8], + padded_input: &[u8], + is_last_block: bool, + global_block_idx: usize, + local_block_idx: usize, + prev_hash: &[C::Word], + mem_helper: &MemoryAuxColsFactory, + ) { + debug_assert_eq!(input.len(), C::BLOCK_CELLS); + debug_assert_eq!(padded_input.len(), C::BLOCK_CELLS); + debug_assert_eq!(read_aux_records.len(), C::NUM_READ_ROWS); + debug_assert_eq!(prev_hash.len(), C::HASH_WORDS); + + let padded_input = (0..C::BLOCK_WORDS) + .map(|i| { + be_limbs_into_word::( + &padded_input[i * C::WORD_U8S..(i + 1) * C::WORD_U8S] + .iter() + .map(|x| *x as u32) + .collect::>(), + ) + }) + .collect::>(); + + let block_start_timestamp = + record.timestamp + (SHA_REGISTER_READS + C::NUM_READ_ROWS * local_block_idx) as u32; + + let read_cells = (C::BLOCK_CELLS * local_block_idx) as 
u32; + let block_start_read_ptr = record.src_ptr + read_cells; + + let message_left = if record.len <= read_cells { + 0 + } else { + (record.len - read_cells) as usize + }; + + // -1 means that padding occurred before the start of the block + // C::ROWS_PER_BLOCK + 1 means that no padding occurred on this block + let first_padding_row = if record.len < read_cells { + -1 + } else if message_left < C::BLOCK_CELLS { + (message_left / C::READ_SIZE) as i32 + } else { + (C::ROWS_PER_BLOCK + 1) as i32 + }; + + // Fill in the VM columns first because the inner `carry_or_buffer` needs to be filled in + block_slice + .par_chunks_exact_mut(C::VM_WIDTH) + .enumerate() + .for_each(|(row_idx, row_slice)| { + // Handle round rows and digest row separately + if row_idx == C::ROWS_PER_BLOCK - 1 { + // This is a digest row + let mut digest_cols = Sha2BlockHasherDigestColsRefMut::::from::( + row_slice[..C::VM_DIGEST_WIDTH].borrow_mut(), + ); + digest_cols.from_state.timestamp = F::from_canonical_u32(record.timestamp); + digest_cols.from_state.pc = F::from_canonical_u32(record.from_pc); + *digest_cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); + *digest_cols.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); + *digest_cols.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); + digest_cols + .dst_ptr + .iter_mut() + .zip(record.dst_ptr.to_le_bytes().map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + digest_cols + .src_ptr + .iter_mut() + .zip(record.src_ptr.to_le_bytes().map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + digest_cols + .len_data + .iter_mut() + .zip(record.len.to_le_bytes().map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + if is_last_block { + digest_cols + .register_reads_aux + .iter_mut() + .zip(record.register_reads_aux.iter()) + .enumerate() + .for_each(|(idx, (cols_read, record_read))| { + mem_helper.fill( + record_read.prev_timestamp, + record.timestamp + idx as u32, + cols_read.as_mut(), + ); + }); + for i in 0..C::NUM_WRITES { + 
digest_cols + .writes_aux_prev_data + .row_mut(i) + .iter_mut() + .zip(record.writes_aux[i].prev_data.map(F::from_canonical_u8)) + .for_each(|(x, y)| *x = y); + + // In the last block we do `C::NUM_READ_ROWS` reads and then write the + // result thus the timestamp of the write is + // `block_start_timestamp + C::NUM_READ_ROWS` + mem_helper.fill( + record.writes_aux[i].prev_timestamp, + block_start_timestamp + C::NUM_READ_ROWS as u32 + i as u32, + &mut digest_cols.writes_aux_base[i], + ); + } + // Need to range check the destination and source pointers + let msl_rshift: u32 = + ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32; + let msl_lshift: u32 = (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS + - self.pointer_max_bits) + as u32; + self.bitwise_lookup_chip.request_range( + (record.dst_ptr >> msl_rshift) << msl_lshift, + (record.src_ptr >> msl_rshift) << msl_lshift, + ); + } else { + // Filling in zeros to make sure the accidental garbage data doesn't + // overflow the prime + digest_cols.register_reads_aux.iter_mut().for_each(|aux| { + mem_helper.fill_zero(aux.as_mut()); + }); + for i in 0..C::NUM_WRITES { + digest_cols.writes_aux_prev_data.row_mut(i).fill(F::ZERO); + mem_helper.fill_zero(&mut digest_cols.writes_aux_base[i]); + } + } + *digest_cols.inner.flags.is_last_block = F::from_bool(is_last_block); + *digest_cols.inner.flags.is_digest_row = F::from_bool(true); + } else { + // This is a round row + let mut round_cols = Sha2BlockHasherRoundColsRefMut::::from::( + row_slice[..C::VM_ROUND_WIDTH].borrow_mut(), + ); + // Take care of the first 4 round rows (aka read rows) + if row_idx < C::NUM_READ_ROWS { + round_cols + .inner + .message_schedule + .carry_or_buffer + .iter_mut() + .zip(input[row_idx * C::READ_SIZE..(row_idx + 1) * C::READ_SIZE].iter()) + .for_each(|(cell, data)| { + *cell = F::from_canonical_u8(*data); + }); + mem_helper.fill( + read_aux_records[row_idx].prev_timestamp, + block_start_timestamp + row_idx as u32, + round_cols.read_aux.as_mut(), + 
); + } else { + mem_helper.fill_zero(round_cols.read_aux.as_mut()); + } + } + // Fill in the control cols, doesn't matter if it is a round or digest row + let mut control_cols = Sha2BlockHasherControlColsRefMut::::from::( + row_slice[..C::BLOCK_HASHER_CONTROL_WIDTH].borrow_mut(), + ); + *control_cols.len = F::from_canonical_u32(record.len); + // Only the first `SHA256_NUM_READ_ROWS` rows increment the timestamp and read ptr + *control_cols.cur_timestamp = F::from_canonical_u32( + block_start_timestamp + min(row_idx, C::NUM_READ_ROWS) as u32, + ); + *control_cols.read_ptr = F::from_canonical_u32( + block_start_read_ptr + (C::READ_SIZE * min(row_idx, C::NUM_READ_ROWS)) as u32, + ); + + // Fill in the padding flags + if row_idx < C::NUM_READ_ROWS { + #[allow(clippy::comparison_chain)] + if (row_idx as i32) < first_padding_row { + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + PaddingFlags::NotPadding as usize, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } else if row_idx as i32 == first_padding_row { + let len = message_left - row_idx * C::READ_SIZE; + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + if row_idx == 3 && is_last_block { + PaddingFlags::FirstPadding0_LastRow + } else { + PaddingFlags::FirstPadding0 + } as usize + + len, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } else { + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + if row_idx == 3 && is_last_block { + PaddingFlags::EntirePaddingLastRow + } else { + PaddingFlags::EntirePadding + } as usize, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + .for_each(|(x, y)| *x = y); + } + } else { + control_cols + .pad_flags + .iter_mut() + .zip( + get_flag_pt_array( + &self.padding_encoder, + PaddingFlags::NotConsidered as usize, + ) + .into_iter() + .map(F::from_canonical_u32), + ) + 
.for_each(|(x, y)| *x = y); + } + if is_last_block && row_idx == C::ROWS_PER_BLOCK - 1 { + // If last digest row, then we set padding_occurred = 0 + *control_cols.padding_occurred = F::ZERO; + } else { + *control_cols.padding_occurred = + F::from_bool((row_idx as i32) >= first_padding_row); + } + }); + + // Fill in the inner trace when the `carry_or_buffer` is filled in + self.inner.generate_block_trace::( + block_slice, + C::VM_WIDTH, + C::BLOCK_HASHER_CONTROL_WIDTH, + &padded_input, + self.bitwise_lookup_chip.clone(), + prev_hash, + is_last_block, + global_block_idx as u32 + 1, // global block index is 1-indexed + local_block_idx as u32, + ); + } +} diff --git a/extensions/sha2/circuit/src/sha2_chips/old/utils.rs b/extensions/sha2/circuit/src/sha2_chips/old/utils.rs new file mode 100644 index 0000000000..d3c78345ad --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/old/utils.rs @@ -0,0 +1,8 @@ +use crate::Sha2ChipConfig; + +/// Returns the number of blocks required to hash a message of length `len` +pub fn get_sha2_num_blocks(len: u32) -> u32 { + // need to pad with one 1 bit, 64 bits for the message length and then pad until the length + // is divisible by [C::BLOCK_BITS] + ((len << 3) as usize + 1 + C::MESSAGE_LENGTH_BITS).div_ceil(C::BLOCK_BITS) as u32 +} diff --git a/extensions/sha2/circuit/src/sha2_chips/test_utils.rs b/extensions/sha2/circuit/src/sha2_chips/test_utils.rs new file mode 100644 index 0000000000..e434f9b111 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/test_utils.rs @@ -0,0 +1,126 @@ +use std::{ + array, + borrow::BorrowMut, + sync::{Arc, Mutex}, +}; + +use hex::FromHex; +use itertools::Itertools; +use openvm_circuit::{ + arch::{ + testing::{ + memory::gen_pointer, TestBuilder, TestChipHarness, VmChipTestBuilder, + BITWISE_OP_LOOKUP_BUS, + }, + Arena, ExecutionBridge, PreflightExecutor, + }, + system::{ + memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, + SystemPort, + }, + utils::get_random_message, +}; +use 
openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, +}; +use openvm_instructions::{ + instruction::Instruction, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS}, + LocalOpcode, +}; +use openvm_sha2_air::{ + word_into_u8_limbs, Sha256Config, Sha2BlockHasherSubairConfig, Sha2Variant, Sha512Config, +}; +use openvm_sha2_transpiler::Rv32Sha2Opcode::{self, *}; +use openvm_stark_backend::{ + interaction::BusIndex, + p3_field::{FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; +#[cfg(feature = "cuda")] +use { + crate::{trace::Sha2BlockHasherRecordMut, Sha2BlockHasherChipGpu}, + openvm_circuit::arch::testing::{ + default_bitwise_lookup_bus, GpuChipTestBuilder, GpuTestChipHarness, + }, +}; + +use crate::{ + Sha2BlockHasherChip, Sha2BlockHasherVmAir, Sha2Config, Sha2MainAir, Sha2MainChip, + Sha2MainChipConfig, Sha2VmExecutor, SHA2_READ_SIZE, SHA2_WRITE_SIZE, +}; + +// See https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf for the padding algorithm +pub fn add_padding_to_message(mut message: Vec) -> Vec { + let message_len_bits = message.len() * 8; + + // For SHA-256, + // l + 1 + k = 448 mod 512 + // <=> l + 1 + k + 8 = 0 mod 512 + // <=> k = -(l + 1 + 8) mod 512 + // <=> k = (512 - (l + 1 + 8)) mod 512 + // The other variants are similar. 
+ let padding_len_bits = (C::BLOCK_BITS + - ((message_len_bits + 1 + C::MESSAGE_LENGTH_BITS) % C::BLOCK_BITS)) + % C::BLOCK_BITS; + message.push(0x80); + + let padding_len_bytes = padding_len_bits / 8; + message.extend(std::iter::repeat_n(0x00, padding_len_bytes)); + + match C::VARIANT { + Sha2Variant::Sha256 => { + message.extend_from_slice(&((message_len_bits as u64).to_be_bytes())); + } + Sha2Variant::Sha512 => { + message.extend_from_slice(&((message_len_bits as u128).to_be_bytes())); + } + Sha2Variant::Sha384 => { + message.extend_from_slice(&((message_len_bits as u128).to_be_bytes())); + } + }; + + message +} + +pub fn write_slice_to_memory( + tester: &mut impl TestBuilder, + data: &[u8], + ptr: usize, +) { + data.chunks_exact(4).enumerate().for_each(|(i, chunk)| { + tester.write::( + RV32_MEMORY_AS as usize, + ptr + i * 4, + chunk + .iter() + .cloned() + .map(F::from_canonical_u8) + .collect_vec() + .try_into() + .unwrap(), + ); + }); +} + +pub fn read_slice_from_memory( + tester: &mut impl TestBuilder, + ptr: usize, + len: usize, +) -> Vec { + let mut data = Vec::new(); + for i in 0..(len / SHA2_READ_SIZE) { + data.extend_from_slice( + &tester.read::(RV32_MEMORY_AS as usize, ptr + i * SHA2_READ_SIZE), + ); + } + data +} diff --git a/extensions/sha2/circuit/src/sha2_chips/tests.rs b/extensions/sha2/circuit/src/sha2_chips/tests.rs new file mode 100644 index 0000000000..12332f6019 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/tests.rs @@ -0,0 +1,728 @@ +use std::{ + array, + borrow::BorrowMut, + sync::{Arc, Mutex}, +}; + +use hex::FromHex; +use itertools::Itertools; +use openvm_circuit::{ + arch::{ + testing::{ + memory::gen_pointer, TestBuilder, TestChipHarness, VmChipTestBuilder, + BITWISE_OP_LOOKUP_BUS, + }, + Arena, ExecutionBridge, MatrixRecordArena, PreflightExecutor, RowMajorMatrixArena, + SizedRecord, + }, + system::{ + memory::{offline_checker::MemoryBridge, SharedMemoryHelper}, + SystemPort, + }, + utils::get_random_message, +}; +use 
openvm_circuit_primitives::bitwise_op_lookup::{ + BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, + SharedBitwiseOperationLookupChip, +}; +use openvm_instructions::{ + instruction::Instruction, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS}, + LocalOpcode, +}; +use openvm_sha2_air::{ + word_into_u8_limbs, Sha256Config, Sha2BlockHasherSubairConfig, Sha2DigestColsRef, + Sha2RoundColsRef, Sha2Variant, Sha384Config, Sha512Config, +}; +use openvm_sha2_transpiler::Rv32Sha2Opcode::{self, *}; +use openvm_stark_backend::{ + interaction::BusIndex, + p3_field::{Field, FieldAlgebra, PrimeField32}, + p3_matrix::{ + dense::{DenseMatrix, RowMajorMatrix}, + Matrix, + }, + utils::disable_debug_builder, + verifier::VerificationError, +}; +use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; +use rand::{rngs::StdRng, Rng}; +#[cfg(feature = "cuda")] +use { + crate::{trace::Sha2BlockHasherRecordMut, Sha2BlockHasherChipGpu}, + openvm_circuit::arch::testing::{ + default_bitwise_lookup_bus, GpuChipTestBuilder, GpuTestChipHarness, + }, +}; + +use crate::{ + add_padding_to_message, read_slice_from_memory, write_slice_to_memory, Sha2BlockHasherChip, + Sha2BlockHasherVmAir, Sha2BlockHasherVmConfig, Sha2BlockHasherVmDigestColsRefMut, Sha2Config, + Sha2MainAir, Sha2MainChip, Sha2MainChipConfig, Sha2Metadata, Sha2RecordLayout, Sha2RecordMut, + Sha2VmExecutor, SHA2_READ_SIZE, SHA2_WRITE_SIZE, +}; + +const SHA2_BUS_IDX: BusIndex = 28; +type F = BabyBear; +const MAX_INS_CAPACITY: usize = 4096; +type Harness = TestChipHarness, Sha2MainAir, Sha2MainChip, RA>; + +fn create_harness_fields( + system_port: SystemPort, + bitwise_chip: SharedBitwiseOperationLookupChip, + memory_helper: SharedMemoryHelper, + pointer_max_bits: usize, +) -> (Sha2MainAir, Sha2VmExecutor, Sha2MainChip) { + let executor = Sha2VmExecutor::::new(Rv32Sha2Opcode::CLASS_OFFSET, pointer_max_bits); + let empty_records = Arc::new(Mutex::new(None)); + let main_chip = Sha2MainChip::new( + 
empty_records.clone(), + bitwise_chip.clone(), + pointer_max_bits, + memory_helper, + ); + let main_air = Sha2MainAir::new( + system_port, + bitwise_chip.bus(), + pointer_max_bits, + SHA2_BUS_IDX, + Rv32Sha2Opcode::CLASS_OFFSET, + ); + (main_air, executor, main_chip) +} + +fn create_test_harness( + tester: &mut VmChipTestBuilder, +) -> ( + Harness, + ( + BitwiseOperationLookupAir, + SharedBitwiseOperationLookupChip, + ), + (Sha2BlockHasherVmAir, Sha2BlockHasherChip), +) { + let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); + let bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let (air, executor, main_chip) = create_harness_fields( + tester.system_port(), + bitwise_chip.clone(), + tester.memory_helper(), + tester.address_bits(), + ); + + let shared_records = main_chip.records.clone(); + + let harness = Harness::::with_capacity(executor, air, main_chip, MAX_INS_CAPACITY); + + let block_hasher_air = Sha2BlockHasherVmAir::new(bitwise_chip.bus(), SHA2_BUS_IDX); + let block_hasher_chip = Sha2BlockHasherChip::new( + bitwise_chip.clone(), + tester.address_bits(), + tester.memory_helper(), + shared_records, + ); + + ( + harness, + (bitwise_chip.air, bitwise_chip), + (block_hasher_air, block_hasher_chip), + ) +} + +// execute one SHA2_UPDATE instruction +#[allow(clippy::too_many_arguments)] +fn set_and_execute_single_block>( + tester: &mut impl TestBuilder, + executor: &mut E, + arena: &mut RA, + rng: &mut StdRng, + opcode: Rv32Sha2Opcode, + message: Option<&[u8]>, + prev_state: Option<&[u8]>, +) { + let rd = gen_pointer(rng, 4); + let rs1 = gen_pointer(rng, 4); + let rs2 = gen_pointer(rng, 4); + + let dst_ptr = gen_pointer(rng, 4); + let state_ptr = gen_pointer(rng, 4); + let input_ptr = gen_pointer(rng, 4); + tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs1, state_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs2, 
input_ptr.to_le_bytes().map(F::from_canonical_u8)); + + let default_message = get_random_message(rng, C::BLOCK_U8S); + let message = message.unwrap_or(&default_message); + assert!(message.len() == C::BLOCK_U8S); + write_slice_to_memory(tester, message, input_ptr); + + let default_prev_state = get_random_message(rng, C::STATE_BYTES); + let prev_state = prev_state.unwrap_or(&default_prev_state); + assert!(prev_state.len() == C::STATE_BYTES); + write_slice_to_memory(tester, prev_state, state_ptr); + + tester.execute( + executor, + arena, + &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), + ); + + let mut state = prev_state.to_vec(); + C::compress(&mut state, message); + let expected_output = state + .iter() + .cloned() + .map(F::from_canonical_u8) + .collect_vec(); + + assert_eq!( + expected_output, + read_slice_from_memory(tester, dst_ptr, C::STATE_BYTES) + ); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS - Single Block Hash +/// +/// Randomly generate computations and execute, ensuring that the generated trace +/// passes all constraints. 
+/////////////////////////////////////////////////////////////////////////////////////// +// Test a single block hash +fn rand_sha2_single_block_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut harness, bitwise, block_hasher) = + create_test_harness::, C>(&mut tester); + + // let num_ops: usize = 10; + let num_ops: usize = 1; + for _ in 0..num_ops { + set_and_execute_single_block::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + None, + None, + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(block_hasher) + .load_periphery(bitwise) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn debug() { + println!("main chip width: {}", Sha256Config::MAIN_CHIP_WIDTH); + println!( + "block hasher chip width: {}", + Sha256Config::BLOCK_HASHER_WIDTH + ); + println!("main chip width: {}", Sha512Config::MAIN_CHIP_WIDTH); + println!( + "block hasher chip width: {}", + Sha512Config::BLOCK_HASHER_WIDTH + ); + println!("main chip width: {}", Sha384Config::MAIN_CHIP_WIDTH); + println!( + "block hasher chip width: {}", + Sha384Config::BLOCK_HASHER_WIDTH + ); + println!(); + println!( + "Sha2RecordMut::size(Sha256): {}", + Sha2RecordMut::size(&Sha2RecordLayout { + metadata: Sha2Metadata { + variant: Sha2Variant::Sha256 + } + }) + ); + println!( + "Sha2RecordMut::size(Sha512): {}", + Sha2RecordMut::size(&Sha2RecordLayout { + metadata: Sha2Metadata { + variant: Sha2Variant::Sha512 + } + }) + ); + println!( + "Sha2RecordMut::size(Sha384): {}", + Sha2RecordMut::size(&Sha2RecordLayout { + metadata: Sha2Metadata { + variant: Sha2Variant::Sha384 + } + }) + ); + println!( + "Sha256: Sha2BlockHasherVmWidth: {}", + Sha256Config::BLOCK_HASHER_WIDTH + ); + println!( + "Sha512: Sha2BlockHasherVmWidth: {}", + Sha512Config::BLOCK_HASHER_WIDTH + ); + println!( + "Sha384: Sha2BlockHasherVmWidth: {}", + Sha384Config::BLOCK_HASHER_WIDTH 
+ ); + println!( + "Sha256: Sha2BlockHasherSubairWidth: {}", + Sha256Config::SUBAIR_WIDTH + ); + println!( + "Sha512: Sha2BlockHasherSubairWidth: {}", + Sha512Config::SUBAIR_WIDTH + ); + println!( + "Sha384: Sha2BlockHasherSubairWidth: {}", + Sha384Config::SUBAIR_WIDTH + ); +} + +#[test] +fn rand_sha256_single_block_test() { + rand_sha2_single_block_test::(); +} + +#[test] +fn rand_sha512_single_block_test() { + rand_sha2_single_block_test::(); +} + +// Note that this test is actually the same as rand_sha512_single_block_test, but we include it here +// for completeness. +#[test] +fn rand_sha384_single_block_test() { + rand_sha2_single_block_test::(); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// POSITIVE TESTS - Multi Block Hash +/// +/// Execute multiple SHA2_UPDATE instructions to hash an entire message +/////////////////////////////////////////////////////////////////////////////////////// +#[allow(clippy::too_many_arguments)] +fn set_and_execute_full_message>( + tester: &mut impl TestBuilder, + executor: &mut E, + arena: &mut RA, + rng: &mut StdRng, + opcode: Rv32Sha2Opcode, + message: Option<&[u8]>, + len: Option, +) { + let rd = gen_pointer(rng, 4); + let rs1 = gen_pointer(rng, 4); + let rs2 = gen_pointer(rng, 4); + + let state_ptr = gen_pointer(rng, 4); + let dst_ptr = state_ptr; + let input_ptr = gen_pointer(rng, 4); + tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs1, state_ptr.to_le_bytes().map(F::from_canonical_u8)); + tester.write(1, rs2, input_ptr.to_le_bytes().map(F::from_canonical_u8)); + + let initial_state: Vec = C::get_h() + .iter() + .cloned() + .flat_map(|x| word_into_u8_limbs::(x).into_iter().rev()) + .map(|x| x.try_into().unwrap()) + .collect_vec(); + assert!(initial_state.len() == C::STATE_BYTES); + write_slice_to_memory(tester, &initial_state, state_ptr); + + let len = len.unwrap_or(rng.gen_range(1..3000)); + let default_message = 
get_random_message(rng, len); + let message = message.map(|x| x.to_vec()).unwrap_or(default_message); + + let expected_output = C::hash(&message); + + let padded_message = add_padding_to_message::(message); + + // run SHA2_UPDATE as many times as needed to hash the entire message + padded_message + .chunks_exact(C::BLOCK_BYTES) + .for_each(|block| { + write_slice_to_memory(tester, block, input_ptr); + + tester.execute( + executor, + arena, + &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), + ); + }); + + let output = read_slice_from_memory(tester, dst_ptr, C::STATE_BYTES) + .into_iter() + .map(|x| x.as_canonical_u32() as u8) + .collect_vec(); + + assert_eq!(expected_output, output); +} + +// Test a single block hash +fn rand_sha2_multi_block_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut harness, bitwise, block_hasher) = create_test_harness::<_, C>(&mut tester); + + let num_ops: usize = 10; + for _ in 0..num_ops { + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + None, + None, + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .load_periphery(block_hasher) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn rand_sha256_multi_block_test() { + rand_sha2_multi_block_test::(); +} + +#[test] +fn rand_sha512_multi_block_test() { + rand_sha2_multi_block_test::(); +} + +// Note that this test is different from rand_sha512_multi_block_test because it uses the initial +// hash state for SHA384 instead of SHA512. +#[test] +fn rand_sha384_multi_block_test() { + rand_sha2_multi_block_test::(); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// EDGE TESTS - Edge Case Input Lengths +/// +/// Test the hash function with various input lengths. 
+/////////////////////////////////////////////////////////////////////////////////////// +fn sha2_edge_test_lengths(test_vectors: &[(&str, &str)]) { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut harness, bitwise, block_hasher) = create_test_harness::<_, C>(&mut tester); + + for (input, _) in test_vectors.iter() { + let input = Vec::from_hex(input).unwrap(); + + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + Some(&input), + None, + ); + } + + // check every possible input length modulo 64 + for i in 65..=128 { + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + None, + Some(i), + ); + } + + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .load_periphery(block_hasher) + .finalize(); + tester.simple_test().expect("Verification failed"); +} + +#[test] +fn sha256_edge_test_lengths() { + let test_vectors = [ + ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), + ( + "98c1c0bdb7d5fea9a88859f06c6c439f", + "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", + ), + ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), + 
("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") + ]; + + sha2_edge_test_lengths::(&test_vectors); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// SANITY TESTS +/// +/// Ensure that solve functions produce the correct results. 
+/////////////////////////////////////////////////////////////////////////////////////// +fn execute_roundtrip_sanity_test() { + let mut rng = create_seeded_rng(); + let mut tester = VmChipTestBuilder::default(); + let (mut harness, _, _) = create_test_harness::, C>(&mut tester); + + println!( + "Sha2DigestCols::width(): {}", + Sha2DigestColsRef::::width::() + ); + println!( + "Sha2RoundCols::width(): {}", + Sha2RoundColsRef::::width::() + ); + + let num_tests: usize = 1; + for _ in 0..num_tests { + set_and_execute_full_message::<_, C, _>( + &mut tester, + &mut harness.executor, + &mut harness.arena, + &mut rng, + C::OPCODE, + None, + None, + ); + } +} + +#[test] +fn execute_roundtrip_sanity_test_sha256() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn execute_roundtrip_sanity_test_sha512() { + execute_roundtrip_sanity_test::(); +} +#[test] +fn execute_roundtrip_sanity_test_sha384() { + execute_roundtrip_sanity_test::(); +} + +#[test] +fn sha256_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = Sha256Config::hash(input); + let expected: [u8; 32] = [ + 99, 196, 61, 185, 226, 212, 131, 80, 154, 248, 97, 108, 157, 55, 200, 226, 160, 73, 207, + 46, 245, 169, 94, 255, 42, 136, 193, 15, 40, 133, 173, 22, + ]; + assert_eq!(output, expected); +} + +#[test] +fn sha512_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; + let output = Sha512Config::hash(input); + let expected: [u8; 64] = [ + 0, 8, 195, 142, 70, 71, 97, 208, 132, 132, 243, 53, 179, 186, 8, 162, 71, 75, 126, 21, 130, + 203, 245, 126, 207, 65, 119, 60, 64, 79, 200, 2, 194, 17, 189, 137, 164, 213, 107, 197, + 152, 11, 242, 165, 146, 80, 96, 105, 249, 27, 139, 14, 244, 21, 118, 31, 94, 87, 32, 145, + 149, 98, 235, 75, + ]; + assert_eq!(output, expected); +} + +#[test] +fn sha384_solve_sanity_check() { + let input = b"Axiom is the best! Axiom is the best! 
Axiom is the best! Axiom is the best!"; + let output = Sha384Config::hash(input); + let expected: [u8; 48] = [ + 134, 227, 167, 229, 35, 110, 115, 174, 10, 27, 197, 116, 56, 144, 150, 36, 152, 190, 212, + 120, 26, 243, 125, 4, 2, 60, 164, 195, 218, 219, 255, 143, 240, 75, 158, 126, 102, 105, 8, + 202, 142, 240, 230, 161, 162, 152, 111, 71, + ]; + assert_eq!(output, expected); +} + +/////////////////////////////////////////////////////////////////////////////////////// +/// NEGATIVE TESTS +/// +/// This tests a soundness bug that was found at one point in our implementation. +/////////////////////////////////////////////////////////////////////////////////////// +fn negative_sha2_test_bad_final_hash() { + let mut tester = VmChipTestBuilder::default(); + let (harness, bitwise, block_hasher) = + create_test_harness::, C>(&mut tester); + + // Set the final_hash to all zeros + let modify_trace = |trace: &mut RowMajorMatrix| { + trace.row_chunks_exact_mut(1).for_each(|row| { + let mut row_slice = row.row_slice(0).to_vec(); + let mut cols = Sha2BlockHasherVmDigestColsRefMut::from::( + &mut row_slice[..C::BLOCK_HASHER_DIGEST_WIDTH], + ); + if cols.inner.flags.is_last_block.is_one() && cols.inner.flags.is_digest_row.is_one() { + for i in 0..C::HASH_WORDS { + for j in 0..C::WORD_U8S { + cols.inner.final_hash[[i, j]] = F::ZERO; + } + } + row.values.copy_from_slice(&row_slice); + } + }); + }; + + disable_debug_builder(); + let tester = tester + .build() + .load(harness) + .load_periphery(bitwise) + .load_periphery_and_prank_trace(block_hasher, modify_trace) + .finalize(); + tester.simple_test_with_expected_error(VerificationError::OodEvaluationMismatch); +} + +#[test] +fn negative_sha256_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} + +#[test] +fn negative_sha512_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} + +#[test] +fn negative_sha384_test_bad_final_hash() { + negative_sha2_test_bad_final_hash::(); +} + +// 
//////////////////////////////////////////////////////////////////////////////////// +// CUDA TESTS +// +// Ensure GPU tracegen is equivalent to CPU tracegen +// //////////////////////////////////////////////////////////////////////////////////// + +#[cfg(feature = "cuda")] +type GpuHarness = + GpuTestChipHarness>; + +#[cfg(feature = "cuda")] +fn create_cuda_harness(tester: &GpuChipTestBuilder) -> GpuHarness { + const GPU_MAX_INS_CAPACITY: usize = 8192; + + let bitwise_bus = default_bitwise_lookup_bus(); + let dummy_bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( + bitwise_bus, + )); + + let (air, executor, cpu_chip) = create_harness_fields( + tester.system_port(), + dummy_bitwise_chip, + tester.dummy_memory_helper(), + tester.address_bits(), + ); + let gpu_chip = Sha256VmChipGpu::new( + tester.range_checker(), + tester.bitwise_op_lookup(), + tester.address_bits() as u32, + tester.timestamp_max_bits() as u32, + ); + + GpuTestChipHarness::with_capacity(executor, air, gpu_chip, cpu_chip, GPU_MAX_INS_CAPACITY) +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_sha256_tracegen() { + let mut rng = create_seeded_rng(); + let mut tester = + GpuChipTestBuilder::default().with_bitwise_op_lookup(default_bitwise_lookup_bus()); + + let mut harness = create_cuda_harness(&tester); + + let num_ops = 70; + for i in 1..=num_ops { + set_and_execute( + &mut tester, + &mut harness.executor, + &mut harness.dense_arena, + &mut rng, + C::OPCODE, + None, + Some(i), + ); + } + + harness + .dense_arena + .get_record_seeker::() + .transfer_to_matrix_arena(&mut harness.matrix_arena); + + tester + .build() + .load_gpu_harness(harness) + .finalize() + .simple_test() + .unwrap(); +} + +#[cfg(feature = "cuda")] +#[test] +fn test_cuda_sha256_known_vectors() { + let mut rng = create_seeded_rng(); + let mut tester = + GpuChipTestBuilder::default().with_bitwise_op_lookup(default_bitwise_lookup_bus()); + + let mut harness = create_cuda_harness(&tester); + + let test_vectors = [ + ("", 
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), + ( + "98c1c0bdb7d5fea9a88859f06c6c439f", + "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", + ), + ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), + ("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") + ]; + + for (input, _) in test_vectors.iter() { + let input = Vec::from_hex(input).unwrap(); + + set_and_execute( + &mut tester, + &mut harness.executor, + &mut harness.dense_arena, + &mut rng, + SHA256, + Some(&input), + None, + ); + } + + harness + .dense_arena + .get_record_seeker::() + .transfer_to_matrix_arena(&mut harness.matrix_arena); + + tester + .build() + .load_gpu_harness(harness) + .finalize() + .simple_test() + .unwrap(); +} diff --git a/extensions/sha2/circuit/src/sha2_chips/trace.rs 
b/extensions/sha2/circuit/src/sha2_chips/trace.rs new file mode 100644 index 0000000000..8952288f42 --- /dev/null +++ b/extensions/sha2/circuit/src/sha2_chips/trace.rs @@ -0,0 +1,311 @@ +use std::{ + borrow::{Borrow, BorrowMut}, + mem::transmute, + slice::{from_raw_parts, from_raw_parts_mut}, +}; + +use openvm_circuit::{ + arch::{ + CustomBorrow, ExecutionError, MultiRowLayout, MultiRowMetadata, PreflightExecutor, + RecordArena, SizedRecord, VmStateMut, + }, + system::memory::{ + offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, + online::TracingMemory, + }, +}; +use openvm_circuit_primitives::AlignedBytesBorrow; +use openvm_instructions::{ + instruction::Instruction, + program::DEFAULT_PC_STEP, + riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, + LocalOpcode, +}; +use openvm_rv32im_circuit::adapters::{tracing_read, tracing_write}; +use openvm_sha2_air::{Sha256Config, Sha2BlockHasherSubairConfig, Sha2Variant, Sha512Config}; +use openvm_stark_backend::p3_field::PrimeField32; + +use crate::{ + Sha2Config, Sha2MainChipConfig, Sha2VmExecutor, SHA2_READ_SIZE, SHA2_REGISTER_READS, + SHA2_WRITE_SIZE, +}; + +#[derive(Clone, Copy)] +pub struct Sha2Metadata { + pub variant: Sha2Variant, +} + +impl MultiRowMetadata for Sha2Metadata { + #[inline(always)] + fn get_num_rows(&self) -> usize { + // The size of the record arena will be height * Sha2MainAir::width() * num_rows. + // We will not use the record arena's buffer for either chip's trace, so we just + // need to ensure that the record arena is large enough to store all the records. + // The size of Sha2RecordMut (in bytes) is less than Sha2MainAir::width() * size_of::(), + // for all SHA-2 variants. Therefore, we can set num_rows = 1. 
+ 1 + } +} + +pub(crate) type Sha2RecordLayout = MultiRowLayout; + +#[repr(C)] +#[derive(AlignedBytesBorrow, Debug, Clone)] +pub struct Sha2RecordHeader { + pub variant: Sha2Variant, + pub from_pc: u32, + pub timestamp: u32, + pub dst_reg_ptr: u32, + pub state_reg_ptr: u32, + pub input_reg_ptr: u32, + pub dst_ptr: u32, + pub state_ptr: u32, + pub input_ptr: u32, + + pub register_reads_aux: [MemoryReadAuxRecord; SHA2_REGISTER_READS], +} + +pub struct Sha2RecordMut<'a> { + pub inner: &'a mut Sha2RecordHeader, + + pub message_bytes: &'a mut [u8], + pub prev_state: &'a mut [u8], + + pub input_reads_aux: &'a mut [MemoryReadAuxRecord], + pub state_reads_aux: &'a mut [MemoryReadAuxRecord], + pub write_aux: &'a mut [MemoryWriteBytesAuxRecord], +} + +impl<'a> CustomBorrow<'a, Sha2RecordMut<'a>, Sha2RecordLayout> for [u8] { + fn custom_borrow(&'a mut self, layout: Sha2RecordLayout) -> Sha2RecordMut<'a> { + // SAFETY: + // - Caller guarantees through the layout that self has sufficient length for all splits and + // constants are guaranteed <= self.len() by layout precondition + + let (header_slice, rest) = + unsafe { self.split_at_mut_unchecked(size_of::()) }; + let record_header: &mut Sha2RecordHeader = header_slice.borrow_mut(); + + let dims = Sha2PreComputeDims::new(layout.metadata.variant); + + let (message_bytes, rest) = unsafe { rest.split_at_mut_unchecked(dims.input_size) }; + let (prev_state, rest) = unsafe { rest.split_at_mut_unchecked(dims.state_size) }; + + let (input_reads_aux, rest) = unsafe { align_to_mut_at(rest, dims.input_reads) }; + let (state_reads_aux, rest) = unsafe { align_to_mut_at(rest, dims.state_reads) }; + let (write_aux, _) = unsafe { align_to_mut_at(rest, dims.state_writes) }; + + Sha2RecordMut { + inner: record_header, + message_bytes, + prev_state, + input_reads_aux, + state_reads_aux, + write_aux, + } + } + + unsafe fn extract_layout(&self) -> Sha2RecordLayout { + let (variant, _) = unsafe { align_to_at(self, 1) }; + let variant = variant[0]; 
+ Sha2RecordLayout { + metadata: Sha2Metadata { variant }, + } + } +} + +unsafe fn align_to_mut_at(slice: &mut [u8], offset: usize) -> (&mut [T], &mut [u8]) { + let (_, items, rest) = unsafe { slice.align_to_mut::() }; + let (items, items_rest) = unsafe { items.split_at_mut_unchecked(offset) }; + let rest = unsafe { + let items_rest: &mut [u8] = transmute(items_rest); + from_raw_parts_mut( + items_rest.as_mut_ptr(), + items_rest.len() * size_of::() + rest.len(), + ) + }; + (items, rest) +} + +unsafe fn align_to_at(slice: &[u8], offset: usize) -> (&[T], &[u8]) { + let (_, items, rest) = unsafe { slice.align_to::() }; + let (items, items_rest) = unsafe { items.split_at_unchecked(offset) }; + let rest = unsafe { + let items_rest: &[u8] = transmute(items_rest); + from_raw_parts( + items_rest.as_ptr(), + items_rest.len() * size_of::() + rest.len(), + ) + }; + (items, rest) +} + +impl SizedRecord for Sha2RecordMut<'_> { + fn size(layout: &Sha2RecordLayout) -> usize { + let header_size = size_of::(); + let dims = Sha2PreComputeDims::new(layout.metadata.variant); + let mut total_len = header_size + + dims.input_size // input + + dims.state_size; // prev_state + + total_len = total_len.next_multiple_of(align_of::()); + total_len += dims.input_reads * size_of::(); + + total_len = total_len.next_multiple_of(align_of::()); + total_len += dims.state_reads * size_of::(); + + total_len = + total_len.next_multiple_of(align_of::>()); + total_len += dims.state_writes * size_of::>(); + + total_len + } + + fn alignment(_layout: &Sha2RecordLayout) -> usize { + // TODO: is this correct? + align_of::() + } +} + +// This is needed in CustomBorrow trait to convert the Sha2Variant that we read from the buffer +// into appropriate dimensions for the record. 
+struct Sha2PreComputeDims { + state_size: usize, + input_size: usize, + input_reads: usize, + state_reads: usize, + state_writes: usize, +} + +impl Sha2PreComputeDims { + fn new(variant: Sha2Variant) -> Self { + match variant { + Sha2Variant::Sha256 => Self { + state_size: Sha256Config::STATE_BYTES, + input_size: Sha256Config::BLOCK_BYTES, + input_reads: Sha256Config::BLOCK_READS, + state_reads: Sha256Config::STATE_READS, + state_writes: Sha256Config::STATE_WRITES, + }, + Sha2Variant::Sha512 => Self { + state_size: Sha512Config::STATE_BYTES, + input_size: Sha512Config::BLOCK_BYTES, + input_reads: Sha512Config::BLOCK_READS, + state_reads: Sha512Config::STATE_READS, + state_writes: Sha512Config::STATE_WRITES, + }, + Sha2Variant::Sha384 => unreachable!(), + } + } +} + +impl PreflightExecutor for Sha2VmExecutor +where + F: PrimeField32, + // for<'buf> RA: RecordArena<'buf, Sha2RecordLayout, Sha2RecordMut<'buf>>, + for<'buf> RA: RecordArena<'buf, Sha2RecordLayout, Sha2RecordMut<'buf>>, +{ + fn get_opcode_name(&self, _: usize) -> String { + format!("{:?}", C::OPCODE) + } + + fn execute( + &self, + state: VmStateMut, + instruction: &Instruction, + ) -> Result<(), ExecutionError> { + let &Instruction { + opcode, + a, + b, + c, + d, + e, + .. 
+ } = instruction; + debug_assert_eq!(opcode, C::OPCODE.global_opcode()); + debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); + debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); + + let record = state.ctx.alloc(Sha2RecordLayout::new(Sha2Metadata { + variant: C::VARIANT, + })); + + record.inner.variant = C::VARIANT; + record.inner.from_pc = *state.pc; + record.inner.timestamp = state.memory.timestamp(); + record.inner.dst_reg_ptr = a.as_canonical_u32(); + record.inner.state_reg_ptr = b.as_canonical_u32(); + record.inner.input_reg_ptr = c.as_canonical_u32(); + + record.inner.dst_ptr = u32::from_le_bytes(tracing_read::( + state.memory, + RV32_REGISTER_AS, + record.inner.dst_reg_ptr, + &mut record.inner.register_reads_aux[0].prev_timestamp, + )); + record.inner.state_ptr = u32::from_le_bytes(tracing_read::( + state.memory, + RV32_REGISTER_AS, + record.inner.state_reg_ptr, + &mut record.inner.register_reads_aux[1].prev_timestamp, + )); + record.inner.input_ptr = u32::from_le_bytes(tracing_read::( + state.memory, + RV32_REGISTER_AS, + record.inner.input_reg_ptr, + &mut record.inner.register_reads_aux[2].prev_timestamp, + )); + + debug_assert!( + record.inner.dst_ptr as usize + C::STATE_BYTES <= (1 << self.pointer_max_bits) + ); + debug_assert!( + record.inner.state_ptr as usize + C::STATE_BYTES <= (1 << self.pointer_max_bits) + ); + debug_assert!( + record.inner.input_ptr as usize + C::BLOCK_BYTES <= (1 << self.pointer_max_bits) + ); + + for idx in 0..C::BLOCK_READS { + let read = tracing_read::( + state.memory, + RV32_MEMORY_AS, + record.inner.input_ptr + (idx * SHA2_READ_SIZE) as u32, + &mut record.input_reads_aux[idx].prev_timestamp, + ); + record.message_bytes[idx * SHA2_READ_SIZE..(idx + 1) * SHA2_READ_SIZE] + .copy_from_slice(&read); + } + + for idx in 0..C::STATE_READS { + let read = tracing_read::( + state.memory, + RV32_MEMORY_AS, + record.inner.state_ptr + (idx * SHA2_READ_SIZE) as u32, + &mut record.state_reads_aux[idx].prev_timestamp, + ); + 
record.prev_state[idx * SHA2_READ_SIZE..(idx + 1) * SHA2_READ_SIZE] + .copy_from_slice(&read); + } + + let mut hash_state = record.prev_state.to_vec(); + C::compress(&mut hash_state, record.message_bytes); + + for idx in 0..C::STATE_WRITES { + tracing_write::( + state.memory, + RV32_MEMORY_AS, + record.inner.dst_ptr + (idx * SHA2_WRITE_SIZE) as u32, + hash_state[idx * SHA2_WRITE_SIZE..(idx + 1) * SHA2_WRITE_SIZE] + .try_into() + .unwrap(), + &mut record.write_aux[idx].prev_timestamp, + &mut record.write_aux[idx].prev_data, + ); + } + + *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); + Ok(()) + } +} diff --git a/extensions/sha256/guest/Cargo.toml b/extensions/sha2/guest/Cargo.toml similarity index 69% rename from extensions/sha256/guest/Cargo.toml rename to extensions/sha2/guest/Cargo.toml index e9d28292b8..1c6503002e 100644 --- a/extensions/sha256/guest/Cargo.toml +++ b/extensions/sha2/guest/Cargo.toml @@ -1,9 +1,9 @@ [package] -name = "openvm-sha256-guest" +name = "openvm-sha2-guest" version.workspace = true authors.workspace = true edition.workspace = true -description = "Guest extension for Sha256" +description = "Guest extension for SHA-2" [dependencies] openvm-platform = { workspace = true } diff --git a/extensions/sha2/guest/src/lib.rs b/extensions/sha2/guest/src/lib.rs new file mode 100644 index 0000000000..8ef93d136e --- /dev/null +++ b/extensions/sha2/guest/src/lib.rs @@ -0,0 +1,147 @@ +#![no_std] + +#[cfg(target_os = "zkvm")] +use openvm_platform::alloc::AlignedBuf; + +/// This is custom-0 defined in RISC-V spec document +pub const OPCODE: u8 = 0x0b; +pub const SHA2_FUNCT3: u8 = 0b100; + +// There is no Sha384 enum variant because the SHA-384 compression function is +// the same as the SHA-512 compression function. 
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
#[repr(u8)]
pub enum Sha2BaseFunct7 {
    Sha256 = 0x1,
    Sha512 = 0x2,
}

/// zkvm native implementation of sha256 compression function
/// # Safety
///
/// The VM accepts the previous hash state and the next block of input, and writes the
/// new hash state.
/// - `prev_state` must point to a buffer of at least 32 bytes
/// - `input` must point to a buffer of at least 64 bytes
/// - `output` must point to a buffer of at least 32 bytes
///
/// [`sha2-256`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[no_mangle]
pub extern "C" fn zkvm_sha256_impl(prev_state: *const u8, input: *const u8, output: *mut u8) {
    // The minimum alignment required for the buffers
    const MIN_ALIGN: usize = 4;
    // SAFETY: unaligned buffers are copied into fresh MIN_ALIGN-aligned allocations.
    // Each `AlignedBuf` is bound to a named local so that the allocation stays alive
    // across the native call. (Taking `.ptr` straight off the unbound temporary, as a
    // previous revision did, drops — and frees — the buffer at the end of the `let`
    // statement, handing the VM a dangling pointer.)
    unsafe {
        let prev_state_buf = (prev_state as usize % MIN_ALIGN != 0)
            .then(|| AlignedBuf::new(prev_state, 32, MIN_ALIGN));
        let prev_state_ptr = prev_state_buf
            .as_ref()
            .map_or(prev_state, |buf| buf.ptr as *const u8);

        let input_buf =
            (input as usize % MIN_ALIGN != 0).then(|| AlignedBuf::new(input, 64, MIN_ALIGN));
        let input_ptr = input_buf.as_ref().map_or(input, |buf| buf.ptr as *const u8);

        // When `output` is unaligned, hash into an aligned scratch buffer and copy back.
        let output_buf =
            (output as usize % MIN_ALIGN != 0).then(|| AlignedBuf::uninit(32, MIN_ALIGN));
        let output_ptr = output_buf.as_ref().map_or(output, |buf| buf.ptr);

        __native_sha256_compress(prev_state_ptr, input_ptr, output_ptr);

        if let Some(buf) = output_buf {
            core::ptr::copy_nonoverlapping(buf.ptr, output, 32);
        }
    }
}

/// zkvm native implementation of sha512 compression function
/// # Safety
///
/// The VM accepts the previous hash state and the next block of input, and writes the
/// new hash state.
/// - `prev_state` must point to a buffer of at least 64 bytes
/// - `input` must point to a buffer of at least 128 bytes
/// - `output` must point to a buffer of at least 64 bytes
///
/// [`sha2-512`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf
#[cfg(target_os = "zkvm")]
#[inline(always)]
#[no_mangle]
pub extern "C" fn zkvm_sha512_impl(prev_state: *const u8, input: *const u8, output: *mut u8) {
    // The minimum alignment required for the buffers
    const MIN_ALIGN: usize = 4;
    // SAFETY: same scheme as `zkvm_sha256_impl`, with SHA-512 buffer sizes; the
    // `AlignedBuf` locals keep every aligned copy alive until after the native call.
    unsafe {
        let prev_state_buf = (prev_state as usize % MIN_ALIGN != 0)
            .then(|| AlignedBuf::new(prev_state, 64, MIN_ALIGN));
        let prev_state_ptr = prev_state_buf
            .as_ref()
            .map_or(prev_state, |buf| buf.ptr as *const u8);

        let input_buf =
            (input as usize % MIN_ALIGN != 0).then(|| AlignedBuf::new(input, 128, MIN_ALIGN));
        let input_ptr = input_buf.as_ref().map_or(input, |buf| buf.ptr as *const u8);

        // When `output` is unaligned, hash into an aligned scratch buffer and copy back.
        let output_buf =
            (output as usize % MIN_ALIGN != 0).then(|| AlignedBuf::uninit(64, MIN_ALIGN));
        let output_ptr = output_buf.as_ref().map_or(output, |buf| buf.ptr);

        __native_sha512_compress(prev_state_ptr, input_ptr, output_ptr);

        if let Some(buf) = output_buf {
            core::ptr::copy_nonoverlapping(buf.ptr, output, 64);
        }
    }
}

/// sha256 compression function intrinsic binding
///
/// # Safety
///
/// The VM accepts the previous hash state and the next block of input, and writes the
/// 32-byte hash.
/// - `prev_state` must point to a buffer of at least 32 bytes
/// - `input` must point to a buffer of at least 64 bytes
/// - `output` must point to a buffer of at least 32 bytes
#[cfg(target_os = "zkvm")]
#[inline(always)]
fn __native_sha256_compress(prev_state: *const u8, input: *const u8, output: *mut u8) {
    openvm_platform::custom_insn_r!(
        opcode = OPCODE,
        funct3 = SHA2_FUNCT3,
        funct7 = Sha2BaseFunct7::Sha256 as u8,
        rd = In output,
        rs1 = In prev_state,
        rs2 = In input
    );
}

/// sha512 intrinsic binding
///
/// # Safety
///
/// The VM accepts the previous hash state and the next block of input, and writes the
/// 64-byte hash.
/// - `prev_state` must point to a buffer of at least 64 bytes
/// - `input` must point to a buffer of at least 128 bytes
/// - `output` must point to a buffer of at least 64 bytes
#[cfg(target_os = "zkvm")]
#[inline(always)]
fn __native_sha512_compress(prev_state: *const u8, input: *const u8, output: *mut u8) {
    openvm_platform::custom_insn_r!(
        opcode = OPCODE,
        funct3 = SHA2_FUNCT3,
        funct7 = Sha2BaseFunct7::Sha512 as u8,
        rd = In output,
        rs1 = In prev_state,
        rs2 = In input
    );
}
= true } openvm-instructions-derive = { workspace = true } strum = { workspace = true } diff --git a/extensions/sha2/transpiler/src/lib.rs b/extensions/sha2/transpiler/src/lib.rs new file mode 100644 index 0000000000..0110d869c3 --- /dev/null +++ b/extensions/sha2/transpiler/src/lib.rs @@ -0,0 +1,58 @@ +use openvm_instructions::{riscv::RV32_MEMORY_AS, LocalOpcode}; +use openvm_instructions_derive::LocalOpcode; +use openvm_sha2_guest::{Sha2BaseFunct7, OPCODE, SHA2_FUNCT3}; +use openvm_stark_backend::p3_field::PrimeField32; +use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput}; +use rrs_lib::instruction_formats::RType; +use strum::{EnumCount, EnumIter, FromRepr}; + +// There is no SHA384 opcode because the SHA-384 compression function is +// the same as the SHA-512 compression function. +#[derive( + Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, +)] +#[opcode_offset = 0x320] +#[repr(usize)] +pub enum Rv32Sha2Opcode { + SHA256, + SHA512, +} + +#[derive(Default)] +pub struct Sha2TranspilerExtension; + +impl TranspilerExtension for Sha2TranspilerExtension { + fn process_custom(&self, instruction_stream: &[u32]) -> Option> { + if instruction_stream.is_empty() { + return None; + } + let instruction_u32 = instruction_stream[0]; + let opcode = (instruction_u32 & 0x7f) as u8; + let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; + + if (opcode, funct3) != (OPCODE, SHA2_FUNCT3) { + return None; + } + let dec_insn = RType::new(instruction_u32); + + if dec_insn.funct7 == Sha2BaseFunct7::Sha256 as u32 { + let instruction = from_r_type( + Rv32Sha2Opcode::SHA256.global_opcode().as_usize(), + RV32_MEMORY_AS as usize, + &dec_insn, + true, + ); + Some(TranspilerOutput::one_to_one(instruction)) + } else if dec_insn.funct7 == Sha2BaseFunct7::Sha512 as u32 { + let instruction = from_r_type( + Rv32Sha2Opcode::SHA512.global_opcode().as_usize(), + RV32_MEMORY_AS as usize, + &dec_insn, + true, + ); + 
Some(TranspilerOutput::one_to_one(instruction)) + } else { + None + } + } +} diff --git a/extensions/sha256/circuit/src/extension/mod.rs b/extensions/sha256/circuit/src/extension/mod.rs deleted file mode 100644 index e14079ac0e..0000000000 --- a/extensions/sha256/circuit/src/extension/mod.rs +++ /dev/null @@ -1,138 +0,0 @@ -use std::{result::Result, sync::Arc}; - -use derive_more::derive::From; -use openvm_circuit::{ - arch::{ - AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, - ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension, - VmExecutionExtension, VmProverExtension, - }, - system::memory::SharedMemoryHelper, -}; -use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, - SharedBitwiseOperationLookupChip, -}; -use openvm_instructions::*; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - p3_field::PrimeField32, - prover::cpu::{CpuBackend, CpuDevice}, -}; -use openvm_stark_sdk::engine::StarkEngine; -use serde::{Deserialize, Serialize}; -use strum::IntoEnumIterator; - -use crate::*; - -cfg_if::cfg_if! 
{ - if #[cfg(feature = "cuda")] { - mod cuda; - pub use self::cuda::*; - pub use self::cuda::Sha256GpuProverExt as Sha256ProverExt; - } else { - pub use self::Sha2CpuProverExt as Sha256ProverExt; - } -} - -// =================================== VM Extension Implementation ================================= -#[derive(Clone, Copy, Debug, Default, Serialize, Deserialize)] -pub struct Sha256; - -#[derive(Clone, From, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)] -pub enum Sha256Executor { - Sha256(Sha256VmExecutor), -} - -impl VmExecutionExtension for Sha256 { - type Executor = Sha256Executor; - - fn extend_execution( - &self, - inventory: &mut ExecutorInventoryBuilder, - ) -> Result<(), ExecutorInventoryError> { - let pointer_max_bits = inventory.pointer_max_bits(); - let sha256_step = Sha256VmExecutor::new(Rv32Sha256Opcode::CLASS_OFFSET, pointer_max_bits); - inventory.add_executor( - sha256_step, - Rv32Sha256Opcode::iter().map(|x| x.global_opcode()), - )?; - - Ok(()) - } -} - -impl VmCircuitExtension for Sha256 { - fn extend_circuit(&self, inventory: &mut AirInventory) -> Result<(), AirInventoryError> { - let pointer_max_bits = inventory.pointer_max_bits(); - - let bitwise_lu = { - let existing_air = inventory.find_air::>().next(); - if let Some(air) = existing_air { - air.bus - } else { - let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx()); - let air = BitwiseOperationLookupAir::<8>::new(bus); - inventory.add_air(air); - air.bus - } - }; - - let sha256 = Sha256VmAir::new( - inventory.system().port(), - bitwise_lu, - pointer_max_bits, - inventory.new_bus_idx(), - ); - inventory.add_air(sha256); - - Ok(()) - } -} - -pub struct Sha2CpuProverExt; -// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker, -// BitwiseOperationLookupChip) are specific to CpuBackend. 
-impl VmProverExtension for Sha2CpuProverExt -where - SC: StarkGenericConfig, - E: StarkEngine, PD = CpuDevice>, - RA: RowMajorMatrixArena>, - Val: PrimeField32, -{ - fn extend_prover( - &self, - _: &Sha256, - inventory: &mut ChipInventory>, - ) -> Result<(), ChipInventoryError> { - let range_checker = inventory.range_checker()?.clone(); - let timestamp_max_bits = inventory.timestamp_max_bits(); - let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits); - let pointer_max_bits = inventory.airs().pointer_max_bits(); - - let bitwise_lu = { - let existing_chip = inventory - .find_chip::>() - .next(); - if let Some(chip) = existing_chip { - chip.clone() - } else { - let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?; - let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus)); - inventory.add_periphery_chip(chip.clone()); - chip - } - }; - - inventory.next_air::()?; - let sha256 = Sha256VmChip::new( - Sha256VmFiller::new(bitwise_lu, pointer_max_bits), - mem_helper, - ); - inventory.add_executor_chip(sha256); - - Ok(()) - } -} diff --git a/extensions/sha256/circuit/src/lib.rs b/extensions/sha256/circuit/src/lib.rs deleted file mode 100644 index 2847e51636..0000000000 --- a/extensions/sha256/circuit/src/lib.rs +++ /dev/null @@ -1,158 +0,0 @@ -#![cfg_attr(feature = "tco", allow(incomplete_features))] -#![cfg_attr(feature = "tco", feature(explicit_tail_calls))] -#![cfg_attr(feature = "tco", feature(core_intrinsics))] - -use std::result::Result; - -use openvm_circuit::{ - arch::{ - AirInventory, ChipInventoryError, InitFileGenerator, MatrixRecordArena, SystemConfig, - VmBuilder, VmChipComplex, VmProverExtension, - }, - system::{SystemChipInventory, SystemCpuBuilder, SystemExecutor}, -}; -use openvm_circuit_derive::VmConfig; -use openvm_rv32im_circuit::{ - Rv32I, Rv32IExecutor, Rv32ImCpuProverExt, Rv32Io, Rv32IoExecutor, Rv32M, Rv32MExecutor, -}; -use openvm_stark_backend::{ - config::{StarkGenericConfig, Val}, - 
p3_field::PrimeField32, - prover::cpu::{CpuBackend, CpuDevice}, -}; -use openvm_stark_sdk::engine::StarkEngine; -use serde::{Deserialize, Serialize}; - -mod sha256_chip; -pub use sha256_chip::*; - -mod extension; -pub use extension::*; - -cfg_if::cfg_if! { - if #[cfg(feature = "cuda")] { - use openvm_circuit::arch::DenseRecordArena; - use openvm_circuit::system::cuda::{extensions::SystemGpuBuilder, SystemChipInventoryGPU}; - use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine, prover_backend::GpuBackend}; - use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Config; - use openvm_rv32im_circuit::Rv32ImGpuProverExt; - pub(crate) mod cuda_abi; - pub use Sha256Rv32GpuBuilder as Sha256Rv32Builder; - } else { - pub use Sha256Rv32CpuBuilder as Sha256Rv32Builder; - } -} - -#[derive(Clone, Debug, VmConfig, derive_new::new, Serialize, Deserialize)] -pub struct Sha256Rv32Config { - #[config(executor = "SystemExecutor")] - pub system: SystemConfig, - #[extension] - pub rv32i: Rv32I, - #[extension] - pub rv32m: Rv32M, - #[extension] - pub io: Rv32Io, - #[extension] - pub sha256: Sha256, -} - -impl Default for Sha256Rv32Config { - fn default() -> Self { - Self { - system: SystemConfig::default(), - rv32i: Rv32I, - rv32m: Rv32M::default(), - io: Rv32Io, - sha256: Sha256, - } - } -} - -// Default implementation uses no init file -impl InitFileGenerator for Sha256Rv32Config {} - -#[derive(Clone)] -pub struct Sha256Rv32CpuBuilder; - -impl VmBuilder for Sha256Rv32CpuBuilder -where - SC: StarkGenericConfig, - E: StarkEngine, PD = CpuDevice>, - Val: PrimeField32, -{ - type VmConfig = Sha256Rv32Config; - type SystemChipInventory = SystemChipInventory; - type RecordArena = MatrixRecordArena>; - - fn create_chip_complex( - &self, - config: &Sha256Rv32Config, - circuit: AirInventory, - ) -> Result< - VmChipComplex, - ChipInventoryError, - > { - let mut chip_complex = - VmBuilder::::create_chip_complex(&SystemCpuBuilder, &config.system, circuit)?; - let inventory = 
&mut chip_complex.inventory; - VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32i, inventory)?; - VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.rv32m, inventory)?; - VmProverExtension::::extend_prover(&Rv32ImCpuProverExt, &config.io, inventory)?; - VmProverExtension::::extend_prover(&Sha2CpuProverExt, &config.sha256, inventory)?; - Ok(chip_complex) - } -} - -#[cfg(feature = "cuda")] -#[derive(Clone)] -pub struct Sha256Rv32GpuBuilder; - -#[cfg(feature = "cuda")] -impl VmBuilder for Sha256Rv32GpuBuilder { - type VmConfig = Sha256Rv32Config; - type SystemChipInventory = SystemChipInventoryGPU; - type RecordArena = DenseRecordArena; - - fn create_chip_complex( - &self, - config: &Sha256Rv32Config, - circuit: AirInventory, - ) -> Result< - VmChipComplex< - BabyBearPoseidon2Config, - Self::RecordArena, - GpuBackend, - Self::SystemChipInventory, - >, - ChipInventoryError, - > { - let mut chip_complex = VmBuilder::::create_chip_complex( - &SystemGpuBuilder, - &config.system, - circuit, - )?; - let inventory = &mut chip_complex.inventory; - VmProverExtension::::extend_prover( - &Rv32ImGpuProverExt, - &config.rv32i, - inventory, - )?; - VmProverExtension::::extend_prover( - &Rv32ImGpuProverExt, - &config.rv32m, - inventory, - )?; - VmProverExtension::::extend_prover( - &Rv32ImGpuProverExt, - &config.io, - inventory, - )?; - VmProverExtension::::extend_prover( - &Sha256GpuProverExt, - &config.sha256, - inventory, - )?; - Ok(chip_complex) - } -} diff --git a/extensions/sha256/circuit/src/sha256_chip/air.rs b/extensions/sha256/circuit/src/sha256_chip/air.rs deleted file mode 100644 index 2fe1cb26c0..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/air.rs +++ /dev/null @@ -1,624 +0,0 @@ -use std::{array, borrow::Borrow, cmp::min}; - -use openvm_circuit::{ - arch::ExecutionBridge, - system::{ - memory::{offline_checker::MemoryBridge, MemoryAddress}, - SystemPort, - }, -}; -use openvm_circuit_primitives::{ - 
bitwise_op_lookup::BitwiseOperationLookupBus, encoder::Encoder, utils::not, SubAir, -}; -use openvm_instructions::{ - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, - LocalOpcode, -}; -use openvm_sha256_air::{ - compose, Sha256Air, SHA256_BLOCK_U8S, SHA256_HASH_WORDS, SHA256_ROUNDS_PER_ROW, - SHA256_WORD_U16S, SHA256_WORD_U8S, -}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{ - interaction::{BusIndex, InteractionBuilder}, - p3_air::{Air, AirBuilder, BaseAir}, - p3_field::{Field, FieldAlgebra}, - p3_matrix::Matrix, - rap::{BaseAirWithPublicValues, PartitionedBaseAir}, -}; - -use super::{ - Sha256VmDigestCols, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, SHA256VM_DIGEST_WIDTH, - SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, SHA256_READ_SIZE, -}; - -/// Sha256VmAir does all constraints related to message padding and -/// the Sha256Air subair constrains the actual hash -#[derive(Clone, Debug)] -pub struct Sha256VmAir { - pub execution_bridge: ExecutionBridge, - pub memory_bridge: MemoryBridge, - /// Bus to send byte checks to - pub bitwise_lookup_bus: BitwiseOperationLookupBus, - /// Maximum number of bits allowed for an address pointer - /// Must be at least 24 - pub ptr_max_bits: usize, - pub(super) sha256_subair: Sha256Air, - pub(super) padding_encoder: Encoder, -} - -impl Sha256VmAir { - pub fn new( - SystemPort { - execution_bus, - program_bus, - memory_bridge, - }: SystemPort, - bitwise_lookup_bus: BitwiseOperationLookupBus, - ptr_max_bits: usize, - self_bus_idx: BusIndex, - ) -> Self { - Self { - execution_bridge: ExecutionBridge::new(execution_bus, program_bus), - memory_bridge, - bitwise_lookup_bus, - ptr_max_bits, - sha256_subair: Sha256Air::new(bitwise_lookup_bus, self_bus_idx), - padding_encoder: Encoder::new(PaddingFlags::COUNT, 2, false), - } - } -} - -impl BaseAirWithPublicValues for Sha256VmAir {} -impl PartitionedBaseAir for Sha256VmAir {} -impl BaseAir for Sha256VmAir { - fn width(&self) -> 
usize { - SHA256VM_WIDTH - } -} - -impl Air for Sha256VmAir { - fn eval(&self, builder: &mut AB) { - self.eval_padding(builder); - self.eval_transitions(builder); - self.eval_reads(builder); - self.eval_last_row(builder); - - self.sha256_subair.eval(builder, SHA256VM_CONTROL_WIDTH); - } -} - -#[allow(dead_code, non_camel_case_types)] -pub(super) enum PaddingFlags { - /// Not considered for padding - W's are not constrained - NotConsidered, - /// Not padding - W's should be equal to the message - NotPadding, - /// FIRST_PADDING_i: it is the first row with padding and there are i cells of non-padding - FirstPadding0, - FirstPadding1, - FirstPadding2, - FirstPadding3, - FirstPadding4, - FirstPadding5, - FirstPadding6, - FirstPadding7, - FirstPadding8, - FirstPadding9, - FirstPadding10, - FirstPadding11, - FirstPadding12, - FirstPadding13, - FirstPadding14, - FirstPadding15, - /// FIRST_PADDING_i_LastRow: it is the first row with padding and there are i cells of - /// non-padding AND it is the last reading row of the message - /// NOTE: if the Last row has padding it has to be at least 9 cells since the last 8 cells are - /// padded with the message length - FirstPadding0_LastRow, - FirstPadding1_LastRow, - FirstPadding2_LastRow, - FirstPadding3_LastRow, - FirstPadding4_LastRow, - FirstPadding5_LastRow, - FirstPadding6_LastRow, - FirstPadding7_LastRow, - /// The entire row is padding AND it is not the first row with padding - /// AND it is the 4th row of the last block of the message - EntirePaddingLastRow, - /// The entire row is padding AND it is not the first row with padding - EntirePadding, -} - -impl PaddingFlags { - /// The number of padding flags (including NotConsidered) - pub const COUNT: usize = EntirePadding as usize + 1; -} - -use PaddingFlags::*; -impl Sha256VmAir { - /// Implement all necessary constraints for the padding - fn eval_padding(&self, builder: &mut AB) { - let main = builder.main(); - let (local, next) = (main.row_slice(0), 
main.row_slice(1)); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - let next_cols: &Sha256VmRoundCols = next[..SHA256VM_ROUND_WIDTH].borrow(); - - // Constrain the sanity of the padding flags - self.padding_encoder - .eval(builder, &local_cols.control.pad_flags); - - builder.assert_one(self.padding_encoder.contains_flag_range::( - &local_cols.control.pad_flags, - NotConsidered as usize..=EntirePadding as usize, - )); - - Self::eval_padding_transitions(self, builder, local_cols, next_cols); - Self::eval_padding_row(self, builder, local_cols); - } - - fn eval_padding_transitions( - &self, - builder: &mut AB, - local: &Sha256VmRoundCols, - next: &Sha256VmRoundCols, - ) { - let next_is_last_row = next.inner.flags.is_digest_row * next.inner.flags.is_last_block; - - // Constrain that `padding_occured` is 1 on a suffix of rows in each message, excluding the - // last digest row, and 0 everywhere else. Furthermore, the suffix starts in the - // first 4 rows of some block. - - builder.assert_bool(local.control.padding_occurred); - // Last round row in the last block has padding_occurred = 1 - // This is the end of the suffix - builder - .when(next_is_last_row.clone()) - .assert_one(local.control.padding_occurred); - - // Digest row in the last block has padding_occurred = 0 - builder - .when(next_is_last_row.clone()) - .assert_zero(next.control.padding_occurred); - - // If padding_occurred = 1 in the current row, then padding_occurred = 1 in the next row, - // unless next is the last digest row - builder - .when(local.control.padding_occurred - next_is_last_row.clone()) - .assert_one(next.control.padding_occurred); - - // If next row is not first 4 rows of a block, then next.padding_occurred = - // local.padding_occurred. So padding_occurred only changes in the first 4 rows of a - // block. 
- builder - .when_transition() - .when(not(next.inner.flags.is_first_4_rows) - next_is_last_row) - .assert_eq( - next.control.padding_occurred, - local.control.padding_occurred, - ); - - // Constrain the that the start of the padding is correct - let next_is_first_padding_row = - next.control.padding_occurred - local.control.padding_occurred; - // Row index if its between 0..4, else 0 - let next_row_idx = self.sha256_subair.row_idx_encoder.flag_with_val::( - &next.inner.flags.row_idx, - &(0..4).map(|x| (x, x)).collect::>(), - ); - // How many non-padding cells there are in the next row. - // Will be 0 on non-padding rows. - let next_padding_offset = self.padding_encoder.flag_with_val::( - &next.control.pad_flags, - &(0..16) - .map(|i| (FirstPadding0 as usize + i, i)) - .collect::>(), - ) + self.padding_encoder.flag_with_val::( - &next.control.pad_flags, - &(0..8) - .map(|i| (FirstPadding0_LastRow as usize + i, i)) - .collect::>(), - ); - - // Will be 0 on last digest row since: - // - padding_occurred = 0 is constrained above - // - next_row_idx = 0 since row_idx is not in 0..4 - // - and next_padding_offset = 0 since `pad_flags = NotConsidered` - let expected_len = next.inner.flags.local_block_idx - * next.control.padding_occurred - * AB::Expr::from_canonical_usize(SHA256_BLOCK_U8S) - + next_row_idx * AB::Expr::from_canonical_usize(SHA256_READ_SIZE) - + next_padding_offset; - - // Note: `next_is_first_padding_row` is either -1,0,1 - // If 1, then this constrains the length of message - // If -1, then `next` must be the last digest row and so this constraint will be 0 == 0 - builder.when(next_is_first_padding_row).assert_eq( - expected_len, - next.control.len * next.control.padding_occurred, - ); - - // Constrain the padding flags are of correct type (eg is not padding or first padding) - let is_next_first_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - FirstPadding0 as usize..=FirstPadding7_LastRow as usize, - ); - - let 
is_next_last_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, - ); - - let is_next_entire_padding = self.padding_encoder.contains_flag_range::( - &next.control.pad_flags, - EntirePaddingLastRow as usize..=EntirePadding as usize, - ); - - let is_next_not_considered = self - .padding_encoder - .contains_flag::(&next.control.pad_flags, &[NotConsidered as usize]); - - let is_next_not_padding = self - .padding_encoder - .contains_flag::(&next.control.pad_flags, &[NotPadding as usize]); - - let is_next_4th_row = self - .sha256_subair - .row_idx_encoder - .contains_flag::(&next.inner.flags.row_idx, &[3]); - - // `pad_flags` is `NotConsidered` on all rows except the first 4 rows of a block - builder.assert_eq( - not(next.inner.flags.is_first_4_rows), - is_next_not_considered, - ); - - // `pad_flags` is `EntirePadding` if the previous row is padding - builder.when(next.inner.flags.is_first_4_rows).assert_eq( - local.control.padding_occurred * next.control.padding_occurred, - is_next_entire_padding, - ); - - // `pad_flags` is `FirstPadding*` if current row is padding and the previous row is not - // padding - builder.when(next.inner.flags.is_first_4_rows).assert_eq( - not(local.control.padding_occurred) * next.control.padding_occurred, - is_next_first_padding, - ); - - // `pad_flags` is `NotPadding` if current row is not padding - builder - .when(next.inner.flags.is_first_4_rows) - .assert_eq(not(next.control.padding_occurred), is_next_not_padding); - - // `pad_flags` is `*LastRow` on the row that contains the last four words of the message - builder - .when(next.inner.flags.is_last_block) - .assert_eq(is_next_4th_row, is_next_last_padding); - } - - fn eval_padding_row( - &self, - builder: &mut AB, - local: &Sha256VmRoundCols, - ) { - let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| { - local.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U8S)] - [i % 
(SHA256_WORD_U8S)] - }); - - let get_ith_byte = |i: usize| { - let word_idx = i / SHA256_ROUNDS_PER_ROW; - let word = local.inner.message_schedule.w[word_idx].map(|x| x.into()); - // Need to reverse the byte order to match the endianness of the memory - let byte_idx = 4 - i % 4 - 1; - compose::(&word[byte_idx * 8..(byte_idx + 1) * 8], 1) - }; - - let is_not_padding = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[NotPadding as usize]); - - // Check the `w`s on case by case basis - for (i, message_byte) in message.iter().enumerate() { - let w = get_ith_byte(i); - let should_be_message = is_not_padding.clone() - + if i < 15 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0 as usize + i + 1..=FirstPadding15 as usize, - ) - } else { - AB::Expr::ZERO - } - + if i < 7 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize + i + 1..=FirstPadding7_LastRow as usize, - ) - } else { - AB::Expr::ZERO - }; - builder - .when(should_be_message) - .assert_eq(w.clone(), *message_byte); - - let should_be_zero = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[EntirePadding as usize]) - + if i < 12 { - self.padding_encoder.contains_flag::( - &local.control.pad_flags, - &[EntirePaddingLastRow as usize], - ) + if i > 0 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize - ..=min( - FirstPadding0_LastRow as usize + i - 1, - FirstPadding7_LastRow as usize, - ), - ) - } else { - AB::Expr::ZERO - } - } else { - AB::Expr::ZERO - } - + if i > 0 { - self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0 as usize..=FirstPadding0 as usize + i - 1, - ) - } else { - AB::Expr::ZERO - }; - builder.when(should_be_zero).assert_zero(w.clone()); - - // Assumes bit-length of message is a multiple of 8 (message is bytes) - // This is true because the message is given as &[u8] 
- let should_be_128 = self - .padding_encoder - .contains_flag::(&local.control.pad_flags, &[FirstPadding0 as usize + i]) - + if i < 8 { - self.padding_encoder.contains_flag::( - &local.control.pad_flags, - &[FirstPadding0_LastRow as usize + i], - ) - } else { - AB::Expr::ZERO - }; - - builder - .when(should_be_128) - .assert_eq(AB::Expr::from_canonical_u32(1 << 7), w); - - // should be len is handled outside of the loop - } - let appended_len = compose::( - &[ - get_ith_byte(15), - get_ith_byte(14), - get_ith_byte(13), - get_ith_byte(12), - ], - RV32_CELL_BITS, - ); - - let actual_len = local.control.len; - - let is_last_padding_row = self.padding_encoder.contains_flag_range::( - &local.control.pad_flags, - FirstPadding0_LastRow as usize..=EntirePaddingLastRow as usize, - ); - - builder.when(is_last_padding_row.clone()).assert_eq( - appended_len * AB::F::from_canonical_usize(RV32_CELL_BITS).inverse(), // bit to byte conversion - actual_len, - ); - - // We constrain that the appended length is in bytes - builder.when(is_last_padding_row.clone()).assert_zero( - local.inner.message_schedule.w[3][0] - + local.inner.message_schedule.w[3][1] - + local.inner.message_schedule.w[3][2], - ); - - // We can't support messages longer than 2^30 bytes because the length has to fit in a field - // element. So, constrain that the first 4 bytes of the length are 0. - // Thus, the bit-length is < 2^32 so the message is < 2^29 bytes. 
- for i in 8..12 { - builder - .when(is_last_padding_row.clone()) - .assert_zero(get_ith_byte(i)); - } - } - /// Implement constraints on `len`, `read_ptr` and `cur_timestamp` - fn eval_transitions(&self, builder: &mut AB) { - let main = builder.main(); - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - let next_cols: &Sha256VmRoundCols = next[..SHA256VM_ROUND_WIDTH].borrow(); - - let is_last_row = - local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row; - - // Len should be the same for the entire message - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq(next_cols.control.len, local_cols.control.len); - - // Read ptr should increment by [SHA256_READ_SIZE] for the first 4 rows and stay the same - // otherwise - let read_ptr_delta = local_cols.inner.flags.is_first_4_rows - * AB::Expr::from_canonical_usize(SHA256_READ_SIZE); - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq( - next_cols.control.read_ptr, - local_cols.control.read_ptr + read_ptr_delta, - ); - - // Timestamp should increment by 1 for the first 4 rows and stay the same otherwise - let timestamp_delta = local_cols.inner.flags.is_first_4_rows * AB::Expr::ONE; - builder - .when_transition() - .when(not::(is_last_row.clone())) - .assert_eq( - next_cols.control.cur_timestamp, - local_cols.control.cur_timestamp + timestamp_delta, - ); - } - - /// Implement the reads for the first 4 rows of a block - fn eval_reads(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local_cols: &Sha256VmRoundCols = local[..SHA256VM_ROUND_WIDTH].borrow(); - - let message: [AB::Var; SHA256_READ_SIZE] = array::from_fn(|i| { - local_cols.inner.message_schedule.carry_or_buffer[i / (SHA256_WORD_U16S * 2)] - [i % (SHA256_WORD_U16S * 2)] - }); - - self.memory_bridge - .read( - MemoryAddress::new( - 
AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - local_cols.control.read_ptr, - ), - message, - local_cols.control.cur_timestamp, - &local_cols.read_aux, - ) - .eval(builder, local_cols.inner.flags.is_first_4_rows); - } - /// Implement the constraints for the last row of a message - fn eval_last_row(&self, builder: &mut AB) { - let main = builder.main(); - let local = main.row_slice(0); - let local_cols: &Sha256VmDigestCols = local[..SHA256VM_DIGEST_WIDTH].borrow(); - - let timestamp: AB::Var = local_cols.from_state.timestamp; - let mut timestamp_delta: usize = 0; - let mut timestamp_pp = || { - timestamp_delta += 1; - timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1) - }; - - let is_last_row = - local_cols.inner.flags.is_last_block * local_cols.inner.flags.is_digest_row; - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rd_ptr, - ), - local_cols.dst_ptr, - timestamp_pp(), - &local_cols.register_reads_aux[0], - ) - .eval(builder, is_last_row.clone()); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rs1_ptr, - ), - local_cols.src_ptr, - timestamp_pp(), - &local_cols.register_reads_aux[1], - ) - .eval(builder, is_last_row.clone()); - - self.memory_bridge - .read( - MemoryAddress::new( - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - local_cols.rs2_ptr, - ), - local_cols.len_data, - timestamp_pp(), - &local_cols.register_reads_aux[2], - ) - .eval(builder, is_last_row.clone()); - - // range check that the memory pointers don't overflow - // Note: no need to range check the length since we read from memory step by step and - // the memory bus will catch any memory accesses beyond ptr_max_bits - let shift = AB::Expr::from_canonical_usize( - 1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.ptr_max_bits), - ); - // This only works if self.ptr_max_bits >= 24 which is typically the case - self.bitwise_lookup_bus - 
.send_range( - // It is fine to shift like this since we already know that dst_ptr and src_ptr - // have [RV32_CELL_BITS] bits - local_cols.dst_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), - local_cols.src_ptr[RV32_REGISTER_NUM_LIMBS - 1] * shift.clone(), - ) - .eval(builder, is_last_row.clone()); - - // the number of reads that happened to read the entire message: we do 4 reads per block - let time_delta = (local_cols.inner.flags.local_block_idx + AB::Expr::ONE) - * AB::Expr::from_canonical_usize(4); - // Every time we read the message we increment the read pointer by SHA256_READ_SIZE - let read_ptr_delta = time_delta.clone() * AB::Expr::from_canonical_usize(SHA256_READ_SIZE); - - let result: [AB::Var; SHA256_WORD_U8S * SHA256_HASH_WORDS] = array::from_fn(|i| { - // The limbs are written in big endian order to the memory so need to be reversed - local_cols.inner.final_hash[i / SHA256_WORD_U8S] - [SHA256_WORD_U8S - i % SHA256_WORD_U8S - 1] - }); - - let dst_ptr_val = - compose::(&local_cols.dst_ptr.map(|x| x.into()), RV32_CELL_BITS); - - // Note: revisit in the future to do 2 block writes of 16 cells instead of 1 block write of - // 32 cells This could be beneficial as the output is often an input for - // another hash - self.memory_bridge - .write( - MemoryAddress::new(AB::Expr::from_canonical_u32(RV32_MEMORY_AS), dst_ptr_val), - result, - timestamp_pp() + time_delta.clone(), - &local_cols.writes_aux, - ) - .eval(builder, is_last_row.clone()); - - self.execution_bridge - .execute_and_increment_pc( - AB::Expr::from_canonical_usize(Rv32Sha256Opcode::SHA256.global_opcode().as_usize()), - [ - local_cols.rd_ptr.into(), - local_cols.rs1_ptr.into(), - local_cols.rs2_ptr.into(), - AB::Expr::from_canonical_u32(RV32_REGISTER_AS), - AB::Expr::from_canonical_u32(RV32_MEMORY_AS), - ], - local_cols.from_state, - AB::Expr::from_canonical_usize(timestamp_delta) + time_delta.clone(), - ) - .eval(builder, is_last_row.clone()); - - // Assert that we read the correct length of 
the message - let len_val = compose::(&local_cols.len_data.map(|x| x.into()), RV32_CELL_BITS); - builder - .when(is_last_row.clone()) - .assert_eq(local_cols.control.len, len_val); - // Assert that we started reading from the correct pointer initially - let src_val = compose::(&local_cols.src_ptr.map(|x| x.into()), RV32_CELL_BITS); - builder - .when(is_last_row.clone()) - .assert_eq(local_cols.control.read_ptr, src_val + read_ptr_delta); - // Assert that we started reading from the correct timestamp - builder.when(is_last_row.clone()).assert_eq( - local_cols.control.cur_timestamp, - local_cols.from_state.timestamp + AB::Expr::from_canonical_u32(3) + time_delta, - ); - } -} diff --git a/extensions/sha256/circuit/src/sha256_chip/columns.rs b/extensions/sha256/circuit/src/sha256_chip/columns.rs deleted file mode 100644 index 38c13a0f73..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/columns.rs +++ /dev/null @@ -1,70 +0,0 @@ -//! WARNING: the order of fields in the structs is important, do not change it - -use openvm_circuit::{ - arch::ExecutionState, - system::memory::offline_checker::{MemoryReadAuxCols, MemoryWriteAuxCols}, -}; -use openvm_circuit_primitives::AlignedBorrow; -use openvm_instructions::riscv::RV32_REGISTER_NUM_LIMBS; -use openvm_sha256_air::{Sha256DigestCols, Sha256RoundCols}; - -use super::{SHA256_REGISTER_READS, SHA256_WRITE_SIZE}; - -/// the first 16 rows of every SHA256 block will be of type Sha256VmRoundCols and the last row will -/// be of type Sha256VmDigestCols -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmRoundCols { - pub control: Sha256VmControlCols, - pub inner: Sha256RoundCols, - pub read_aux: MemoryReadAuxCols, -} - -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmDigestCols { - pub control: Sha256VmControlCols, - pub inner: Sha256DigestCols, - - pub from_state: ExecutionState, - /// It is counter intuitive, but we will constrain the register reads on the very last 
row of - /// every message - pub rd_ptr: T, - pub rs1_ptr: T, - pub rs2_ptr: T, - pub dst_ptr: [T; RV32_REGISTER_NUM_LIMBS], - pub src_ptr: [T; RV32_REGISTER_NUM_LIMBS], - pub len_data: [T; RV32_REGISTER_NUM_LIMBS], - pub register_reads_aux: [MemoryReadAuxCols; SHA256_REGISTER_READS], - pub writes_aux: MemoryWriteAuxCols, -} - -/// These are the columns that are used on both round and digest rows -#[repr(C)] -#[derive(Clone, Copy, Debug, AlignedBorrow)] -pub struct Sha256VmControlCols { - /// Note: We will use the buffer in `inner.message_schedule` as the message data - /// This is the length of the entire message in bytes - pub len: T, - /// Need to keep timestamp and read_ptr since block reads don't have the necessary information - pub cur_timestamp: T, - pub read_ptr: T, - /// Padding flags which will be used to encode the the number of non-padding cells in the - /// current row - pub pad_flags: [T; 6], - /// A boolean flag that indicates whether a padding already occurred - pub padding_occurred: T, -} - -/// Width of the Sha256VmControlCols -pub const SHA256VM_CONTROL_WIDTH: usize = Sha256VmControlCols::::width(); -/// Width of the Sha256VmRoundCols -pub const SHA256VM_ROUND_WIDTH: usize = Sha256VmRoundCols::::width(); -/// Width of the Sha256VmDigestCols -pub const SHA256VM_DIGEST_WIDTH: usize = Sha256VmDigestCols::::width(); -/// Width of the Sha256Cols -pub const SHA256VM_WIDTH: usize = if SHA256VM_ROUND_WIDTH > SHA256VM_DIGEST_WIDTH { - SHA256VM_ROUND_WIDTH -} else { - SHA256VM_DIGEST_WIDTH -}; diff --git a/extensions/sha256/circuit/src/sha256_chip/tests.rs b/extensions/sha256/circuit/src/sha256_chip/tests.rs deleted file mode 100644 index 4f72ffd333..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/tests.rs +++ /dev/null @@ -1,376 +0,0 @@ -use std::{array, sync::Arc}; - -use hex::FromHex; -use openvm_circuit::{ - arch::{ - testing::{ - memory::gen_pointer, TestBuilder, TestChipHarness, VmChipTestBuilder, - BITWISE_OP_LOOKUP_BUS, - }, - Arena, 
MatrixRecordArena, PreflightExecutor, - }, - system::{memory::SharedMemoryHelper, SystemPort}, - utils::get_random_message, -}; -use openvm_circuit_primitives::bitwise_op_lookup::{ - BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip, - SharedBitwiseOperationLookupChip, -}; -use openvm_instructions::{ - instruction::Instruction, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS}, - LocalOpcode, -}; -use openvm_sha256_air::{get_sha256_num_blocks, SHA256_BLOCK_U8S}; -use openvm_sha256_transpiler::Rv32Sha256Opcode::{self, *}; -use openvm_stark_backend::{interaction::BusIndex, p3_field::FieldAlgebra}; -use openvm_stark_sdk::{p3_baby_bear::BabyBear, utils::create_seeded_rng}; -use rand::{rngs::StdRng, Rng}; -#[cfg(feature = "cuda")] -use { - crate::{Sha256VmChipGpu, Sha256VmRecordMut}, - openvm_circuit::arch::testing::{ - default_bitwise_lookup_bus, GpuChipTestBuilder, GpuTestChipHarness, - }, -}; - -use super::{Sha256VmAir, Sha256VmChip, Sha256VmExecutor}; -use crate::{sha256_solve, Sha256VmDigestCols, Sha256VmFiller, Sha256VmRoundCols}; - -type F = BabyBear; -const SELF_BUS_IDX: BusIndex = 28; -const MAX_INS_CAPACITY: usize = 4096; -type Harness = TestChipHarness, RA>; - -fn create_harness_fields( - system_port: SystemPort, - bitwise_chip: Arc>, - memory_helper: SharedMemoryHelper, - address_bits: usize, -) -> (Sha256VmAir, Sha256VmExecutor, Sha256VmChip) { - let air = Sha256VmAir::new(system_port, bitwise_chip.bus(), address_bits, SELF_BUS_IDX); - let executor = Sha256VmExecutor::new(Rv32Sha256Opcode::CLASS_OFFSET, address_bits); - let chip = Sha256VmChip::new( - Sha256VmFiller::new(bitwise_chip, address_bits), - memory_helper, - ); - (air, executor, chip) -} - -fn create_harness( - tester: &mut VmChipTestBuilder, -) -> ( - Harness, - ( - BitwiseOperationLookupAir, - SharedBitwiseOperationLookupChip, - ), -) { - let bitwise_bus = BitwiseOperationLookupBus::new(BITWISE_OP_LOOKUP_BUS); - let bitwise_chip = 
Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - let (air, executor, chip) = create_harness_fields( - tester.system_port(), - bitwise_chip.clone(), - tester.memory_helper(), - tester.address_bits(), - ); - let harness = Harness::::with_capacity(executor, air, chip, MAX_INS_CAPACITY); - (harness, (bitwise_chip.air, bitwise_chip)) -} - -fn set_and_execute>( - tester: &mut impl TestBuilder, - executor: &mut E, - arena: &mut RA, - rng: &mut StdRng, - opcode: Rv32Sha256Opcode, - message: Option<&[u8]>, - len: Option, -) { - let len = len.unwrap_or(rng.gen_range(1..3000)); - let tmp = get_random_message(rng, len); - let message: &[u8] = message.unwrap_or(&tmp); - let len = message.len(); - - let rd = gen_pointer(rng, 4); - let rs1 = gen_pointer(rng, 4); - let rs2 = gen_pointer(rng, 4); - - let dst_ptr = gen_pointer(rng, 4); - let src_ptr = gen_pointer(rng, 4); - tester.write(1, rd, dst_ptr.to_le_bytes().map(F::from_canonical_u8)); - tester.write(1, rs1, src_ptr.to_le_bytes().map(F::from_canonical_u8)); - tester.write(1, rs2, len.to_le_bytes().map(F::from_canonical_u8)); - - // Adding random memory after the message - let num_blocks = get_sha256_num_blocks(len as u32) as usize; - for offset in (0..num_blocks * SHA256_BLOCK_U8S).step_by(4) { - let chunk: [F; 4] = array::from_fn(|i| { - if offset + i < message.len() { - F::from_canonical_u8(message[offset + i]) - } else { - F::from_canonical_u8(rng.gen()) - } - }); - - tester.write(RV32_MEMORY_AS as usize, src_ptr + offset, chunk); - } - - tester.execute( - executor, - arena, - &Instruction::from_usize(opcode.global_opcode(), [rd, rs1, rs2, 1, 2]), - ); - - let output = sha256_solve(message); - assert_eq!( - output.map(F::from_canonical_u8), - tester.read::<32>(RV32_MEMORY_AS as usize, dst_ptr) - ); -} - -/////////////////////////////////////////////////////////////////////////////////////// -/// POSITIVE TESTS -/// -/// Randomly generate computations and execute, ensuring that the generated trace -/// 
passes all constraints. -/////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn rand_sha256_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let (mut harness, bitwise) = create_harness(&mut tester); - - let num_ops: usize = 10; - for _ in 0..num_ops { - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.arena, - &mut rng, - SHA256, - None, - None, - ); - } - - let tester = tester - .build() - .load(harness) - .load_periphery(bitwise) - .finalize(); - tester.simple_test().expect("Verification failed"); -} - -#[test] -fn sha256_edge_test_lengths() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let (mut harness, bitwise) = create_harness(&mut tester); - - let test_vectors = [ - ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), - ( - "98c1c0bdb7d5fea9a88859f06c6c439f", - "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", - ), - ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), - 
("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") - ]; - - for (input, _) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.arena, - &mut rng, - SHA256, - Some(&input), - None, - ); - } - - // check every possible input length modulo 64 - for i in 65..=128 { - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.arena, - &mut rng, - SHA256, - None, - Some(i), - ); - } - - let tester = tester - .build() - .load(harness) - .load_periphery(bitwise) - .finalize(); - tester.simple_test().expect("Verification failed"); -} - -/////////////////////////////////////////////////////////////////////////////////////// -/// SANITY TESTS -/// -/// Ensure that solve functions produce the correct results. 
-/////////////////////////////////////////////////////////////////////////////////////// -#[test] -fn execute_roundtrip_sanity_test() { - let mut rng = create_seeded_rng(); - let mut tester = VmChipTestBuilder::default(); - let (mut harness, _) = create_harness::>(&mut tester); - - println!( - "Sha256VmDigestCols::width(): {}", - Sha256VmDigestCols::::width() - ); - println!( - "Sha256VmRoundCols::width(): {}", - Sha256VmRoundCols::::width() - ); - let num_tests: usize = 1; - for _ in 0..num_tests { - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.arena, - &mut rng, - SHA256, - None, - None, - ); - } -} - -#[test] -fn sha256_solve_sanity_check() { - let input = b"Axiom is the best! Axiom is the best! Axiom is the best! Axiom is the best!"; - let output = sha256_solve(input); - let expected: [u8; 32] = [ - 99, 196, 61, 185, 226, 212, 131, 80, 154, 248, 97, 108, 157, 55, 200, 226, 160, 73, 207, - 46, 245, 169, 94, 255, 42, 136, 193, 15, 40, 133, 173, 22, - ]; - assert_eq!(output, expected); -} - -// //////////////////////////////////////////////////////////////////////////////////// -// CUDA TESTS -// -// Ensure GPU tracegen is equivalent to CPU tracegen -// //////////////////////////////////////////////////////////////////////////////////// - -#[cfg(feature = "cuda")] -type GpuHarness = - GpuTestChipHarness>; - -#[cfg(feature = "cuda")] -fn create_cuda_harness(tester: &GpuChipTestBuilder) -> GpuHarness { - const GPU_MAX_INS_CAPACITY: usize = 8192; - - let bitwise_bus = default_bitwise_lookup_bus(); - let dummy_bitwise_chip = Arc::new(BitwiseOperationLookupChip::::new( - bitwise_bus, - )); - - let (air, executor, cpu_chip) = create_harness_fields( - tester.system_port(), - dummy_bitwise_chip, - tester.dummy_memory_helper(), - tester.address_bits(), - ); - let gpu_chip = Sha256VmChipGpu::new( - tester.range_checker(), - tester.bitwise_op_lookup(), - tester.address_bits() as u32, - tester.timestamp_max_bits() as u32, - ); - - 
GpuTestChipHarness::with_capacity(executor, air, gpu_chip, cpu_chip, GPU_MAX_INS_CAPACITY) -} - -#[cfg(feature = "cuda")] -#[test] -fn test_cuda_sha256_tracegen() { - let mut rng = create_seeded_rng(); - let mut tester = - GpuChipTestBuilder::default().with_bitwise_op_lookup(default_bitwise_lookup_bus()); - - let mut harness = create_cuda_harness(&tester); - - let num_ops = 70; - for i in 1..=num_ops { - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.dense_arena, - &mut rng, - SHA256, - None, - Some(i), - ); - } - - harness - .dense_arena - .get_record_seeker::() - .transfer_to_matrix_arena(&mut harness.matrix_arena); - - tester - .build() - .load_gpu_harness(harness) - .finalize() - .simple_test() - .unwrap(); -} - -#[cfg(feature = "cuda")] -#[test] -fn test_cuda_sha256_known_vectors() { - let mut rng = create_seeded_rng(); - let mut tester = - GpuChipTestBuilder::default().with_bitwise_op_lookup(default_bitwise_lookup_bus()); - - let mut harness = create_cuda_harness(&tester); - - let test_vectors = [ - ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), - ( - "98c1c0bdb7d5fea9a88859f06c6c439f", - "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", - ), - ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), - 
("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") - ]; - - for (input, _) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - - set_and_execute( - &mut tester, - &mut harness.executor, - &mut harness.dense_arena, - &mut rng, - SHA256, - Some(&input), - None, - ); - } - - harness - .dense_arena - .get_record_seeker::() - .transfer_to_matrix_arena(&mut harness.matrix_arena); - - tester - .build() - .load_gpu_harness(harness) - .finalize() - .simple_test() - .unwrap(); -} diff --git a/extensions/sha256/circuit/src/sha256_chip/trace.rs b/extensions/sha256/circuit/src/sha256_chip/trace.rs deleted file mode 100644 index 7fc5c7062c..0000000000 --- a/extensions/sha256/circuit/src/sha256_chip/trace.rs +++ /dev/null @@ -1,624 +0,0 @@ -use std::{ - array, - borrow::{Borrow, BorrowMut}, - cmp::min, -}; - -use openvm_circuit::{ - arch::*, - system::memory::{ - offline_checker::{MemoryReadAuxRecord, MemoryWriteBytesAuxRecord}, - 
online::TracingMemory, - MemoryAuxColsFactory, - }, -}; -use openvm_circuit_primitives::AlignedBytesBorrow; -use openvm_instructions::{ - instruction::Instruction, - program::DEFAULT_PC_STEP, - riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS}, - LocalOpcode, -}; -use openvm_rv32im_circuit::adapters::{read_rv32_register, tracing_read, tracing_write}; -use openvm_sha256_air::{ - get_flag_pt_array, get_sha256_num_blocks, Sha256FillerHelper, SHA256_BLOCK_BITS, SHA256_H, - SHA256_ROWS_PER_BLOCK, -}; -use openvm_sha256_transpiler::Rv32Sha256Opcode; -use openvm_stark_backend::{ - p3_field::PrimeField32, - p3_matrix::{dense::RowMajorMatrix, Matrix}, - p3_maybe_rayon::prelude::*, -}; - -use super::{ - Sha256VmDigestCols, Sha256VmExecutor, Sha256VmRoundCols, SHA256VM_CONTROL_WIDTH, - SHA256VM_DIGEST_WIDTH, -}; -use crate::{ - sha256_chip::{PaddingFlags, SHA256_READ_SIZE, SHA256_REGISTER_READS, SHA256_WRITE_SIZE}, - sha256_solve, Sha256VmControlCols, Sha256VmFiller, SHA256VM_ROUND_WIDTH, SHA256VM_WIDTH, - SHA256_BLOCK_CELLS, SHA256_MAX_MESSAGE_LEN, SHA256_NUM_READ_ROWS, -}; - -#[derive(Clone, Copy)] -pub struct Sha256VmMetadata { - pub num_blocks: u32, -} - -impl MultiRowMetadata for Sha256VmMetadata { - #[inline(always)] - fn get_num_rows(&self) -> usize { - self.num_blocks as usize * SHA256_ROWS_PER_BLOCK - } -} - -pub(crate) type Sha256VmRecordLayout = MultiRowLayout; - -#[repr(C)] -#[derive(AlignedBytesBorrow, Debug, Clone)] -pub struct Sha256VmRecordHeader { - pub from_pc: u32, - pub timestamp: u32, - pub rd_ptr: u32, - pub rs1_ptr: u32, - pub rs2_ptr: u32, - pub dst_ptr: u32, - pub src_ptr: u32, - pub len: u32, - - pub register_reads_aux: [MemoryReadAuxRecord; SHA256_REGISTER_READS], - pub write_aux: MemoryWriteBytesAuxRecord, -} - -pub struct Sha256VmRecordMut<'a> { - pub inner: &'a mut Sha256VmRecordHeader, - // Having a continuous slice of the input is useful for fast hashing in `execute` - pub input: &'a mut [u8], - pub read_aux: 
&'a mut [MemoryReadAuxRecord], -} - -/// Custom borrowing that splits the buffer into a fixed `Sha256VmRecord` header -/// followed by a slice of `u8`'s of length `SHA256_BLOCK_CELLS * num_blocks` where `num_blocks` is -/// provided at runtime, followed by a slice of `MemoryReadAuxRecord`'s of length -/// `SHA256_NUM_READ_ROWS * num_blocks`. Uses `align_to_mut()` to make sure the slice is properly -/// aligned to `MemoryReadAuxRecord`. Has debug assertions that check the size and alignment of the -/// slices. -impl<'a> CustomBorrow<'a, Sha256VmRecordMut<'a>, Sha256VmRecordLayout> for [u8] { - fn custom_borrow(&'a mut self, layout: Sha256VmRecordLayout) -> Sha256VmRecordMut<'a> { - // SAFETY: - // - Caller guarantees through the layout that self has sufficient length for all splits and - // constants are guaranteed <= self.len() by layout precondition - let (header_buf, rest) = - unsafe { self.split_at_mut_unchecked(size_of::()) }; - - // SAFETY: - // - layout guarantees rest has sufficient length for input data - // - The layout size calculation includes num_blocks * SHA256_BLOCK_CELLS bytes after header - // - num_blocks is derived from the message length ensuring correct sizing - // - SHA256_BLOCK_CELLS is a compile-time constant (64 bytes per block) - let (input, rest) = unsafe { - rest.split_at_mut_unchecked((layout.metadata.num_blocks as usize) * SHA256_BLOCK_CELLS) - }; - - // SAFETY: - // - rest is a valid mutable slice from the previous split - // - align_to_mut guarantees the middle slice is properly aligned for MemoryReadAuxRecord - // - The subslice operation [..num_blocks * SHA256_NUM_READ_ROWS] validates sufficient - // capacity - // - Layout calculation ensures space for alignment padding plus required aux records - let (_, read_aux_buf, _) = unsafe { rest.align_to_mut::() }; - Sha256VmRecordMut { - inner: header_buf.borrow_mut(), - input, - read_aux: &mut read_aux_buf - [..(layout.metadata.num_blocks as usize) * SHA256_NUM_READ_ROWS], - } - } - - 
unsafe fn extract_layout(&self) -> Sha256VmRecordLayout { - let header: &Sha256VmRecordHeader = self.borrow(); - Sha256VmRecordLayout { - metadata: Sha256VmMetadata { - num_blocks: get_sha256_num_blocks(header.len), - }, - } - } -} - -impl SizedRecord for Sha256VmRecordMut<'_> { - fn size(layout: &Sha256VmRecordLayout) -> usize { - let mut total_len = size_of::(); - total_len += layout.metadata.num_blocks as usize * SHA256_BLOCK_CELLS; - // Align the pointer to the alignment of `MemoryReadAuxRecord` - total_len = total_len.next_multiple_of(align_of::()); - total_len += layout.metadata.num_blocks as usize - * SHA256_NUM_READ_ROWS - * size_of::(); - total_len - } - - fn alignment(_layout: &Sha256VmRecordLayout) -> usize { - align_of::() - } -} - -impl PreflightExecutor for Sha256VmExecutor -where - F: PrimeField32, - for<'buf> RA: RecordArena<'buf, Sha256VmRecordLayout, Sha256VmRecordMut<'buf>>, -{ - fn get_opcode_name(&self, _: usize) -> String { - format!("{:?}", Rv32Sha256Opcode::SHA256) - } - - fn execute( - &self, - state: VmStateMut, - instruction: &Instruction, - ) -> Result<(), ExecutionError> { - let Instruction { - opcode, - a, - b, - c, - d, - e, - .. 
- } = instruction; - debug_assert_eq!(*opcode, Rv32Sha256Opcode::SHA256.global_opcode()); - debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS); - debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS); - - // Reading the length first to allocate a record of correct size - let len = read_rv32_register(state.memory.data(), c.as_canonical_u32()); - - let num_blocks = get_sha256_num_blocks(len); - let record = state.ctx.alloc(MultiRowLayout { - metadata: Sha256VmMetadata { num_blocks }, - }); - - record.inner.from_pc = *state.pc; - record.inner.timestamp = state.memory.timestamp(); - record.inner.rd_ptr = a.as_canonical_u32(); - record.inner.rs1_ptr = b.as_canonical_u32(); - record.inner.rs2_ptr = c.as_canonical_u32(); - - record.inner.dst_ptr = u32::from_le_bytes(tracing_read( - state.memory, - RV32_REGISTER_AS, - record.inner.rd_ptr, - &mut record.inner.register_reads_aux[0].prev_timestamp, - )); - record.inner.src_ptr = u32::from_le_bytes(tracing_read( - state.memory, - RV32_REGISTER_AS, - record.inner.rs1_ptr, - &mut record.inner.register_reads_aux[1].prev_timestamp, - )); - record.inner.len = u32::from_le_bytes(tracing_read( - state.memory, - RV32_REGISTER_AS, - record.inner.rs2_ptr, - &mut record.inner.register_reads_aux[2].prev_timestamp, - )); - - // we will read [num_blocks] * [SHA256_BLOCK_CELLS] cells but only [len] cells will be used - debug_assert!( - record.inner.src_ptr as usize + num_blocks as usize * SHA256_BLOCK_CELLS - <= (1 << self.pointer_max_bits) - ); - debug_assert!( - record.inner.dst_ptr as usize + SHA256_WRITE_SIZE <= (1 << self.pointer_max_bits) - ); - // We don't support messages longer than 2^29 bytes - debug_assert!(record.inner.len < SHA256_MAX_MESSAGE_LEN as u32); - - for block_idx in 0..num_blocks as usize { - // Reads happen on the first 4 rows of each block - for row in 0..SHA256_NUM_READ_ROWS { - let read_idx = block_idx * SHA256_NUM_READ_ROWS + row; - let row_input: [u8; SHA256_READ_SIZE] = tracing_read( - state.memory, - 
RV32_MEMORY_AS, - record.inner.src_ptr + (read_idx * SHA256_READ_SIZE) as u32, - &mut record.read_aux[read_idx].prev_timestamp, - ); - record.input[read_idx * SHA256_READ_SIZE..(read_idx + 1) * SHA256_READ_SIZE] - .copy_from_slice(&row_input); - } - } - - let output = sha256_solve(&record.input[..len as usize]); - tracing_write( - state.memory, - RV32_MEMORY_AS, - record.inner.dst_ptr, - output, - &mut record.inner.write_aux.prev_timestamp, - &mut record.inner.write_aux.prev_data, - ); - - *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP); - - Ok(()) - } -} - -impl TraceFiller for Sha256VmFiller { - fn fill_trace( - &self, - mem_helper: &MemoryAuxColsFactory, - trace_matrix: &mut RowMajorMatrix, - rows_used: usize, - ) { - if rows_used == 0 { - return; - } - - let mut chunks = Vec::with_capacity(trace_matrix.height() / SHA256_ROWS_PER_BLOCK); - let mut sizes = Vec::with_capacity(trace_matrix.height() / SHA256_ROWS_PER_BLOCK); - let mut trace = &mut trace_matrix.values[..]; - let mut num_blocks_so_far = 0; - - // First pass over the trace to get the number of blocks for each instruction - // and divide the matrix into chunks of needed sizes - loop { - if num_blocks_so_far * SHA256_ROWS_PER_BLOCK >= rows_used { - // Push all the padding rows as a single chunk and break - chunks.push(trace); - sizes.push((0, num_blocks_so_far)); - break; - } else { - // SAFETY: - // - caller ensures `trace` contains a valid record representation that was - // previously written by the executor - // - header is the first element of the record - let record: &Sha256VmRecordHeader = - unsafe { get_record_from_slice(&mut trace, ()) }; - let num_blocks = ((record.len << 3) as usize + 1 + 64).div_ceil(SHA256_BLOCK_BITS); - let (chunk, rest) = - trace.split_at_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK * num_blocks); - chunks.push(chunk); - sizes.push((num_blocks, num_blocks_so_far)); - num_blocks_so_far += num_blocks; - trace = rest; - } - } - - // During the first pass we will fill out 
most of the matrix - // But there are some cells that can't be generated by the first pass so we will do a second - // pass over the matrix later - chunks.par_iter_mut().zip(sizes.par_iter()).for_each( - |(slice, (num_blocks, global_block_offset))| { - if global_block_offset * SHA256_ROWS_PER_BLOCK >= rows_used { - // Fill in the invalid rows - slice.par_chunks_mut(SHA256VM_WIDTH).for_each(|row| { - // Need to get rid of the accidental garbage data that might overflow the - // F's prime field. Unfortunately, there is no good way around this - // SAFETY: - // - row has exactly SHA256VM_WIDTH elements - // - We're zeroing all SHA256VM_WIDTH elements to clear any garbage data - // that might overflow the field - // - Casting F* to u8* preserves validity for write_bytes operation - // - SHA256VM_WIDTH * size_of::() correctly calculates total bytes to - // zero - unsafe { - std::ptr::write_bytes( - row.as_mut_ptr() as *mut u8, - 0, - SHA256VM_WIDTH * size_of::(), - ); - } - let cols: &mut Sha256VmRoundCols = - row[..SHA256VM_ROUND_WIDTH].borrow_mut(); - self.inner.generate_default_row(&mut cols.inner); - }); - return; - } - - // SAFETY: - // - caller ensures `trace` contains a valid record representation that was - // previously written by the executor - // - slice contains a valid Sha256VmRecord with the exact layout specified - // - get_record_from_slice will correctly split the buffer into header, input, and - // aux components based on this layout - let record: Sha256VmRecordMut = unsafe { - get_record_from_slice( - slice, - Sha256VmRecordLayout { - metadata: Sha256VmMetadata { - num_blocks: *num_blocks as u32, - }, - }, - ) - }; - - let mut input: Vec = Vec::with_capacity(SHA256_BLOCK_CELLS * num_blocks); - input.extend_from_slice(record.input); - let mut padded_input = input.clone(); - let len = record.inner.len as usize; - let padded_input_len = padded_input.len(); - padded_input[len] = 1 << (RV32_CELL_BITS - 1); - padded_input[len + 1..padded_input_len - 
4].fill(0); - padded_input[padded_input_len - 4..] - .copy_from_slice(&((len as u32) << 3).to_be_bytes()); - - let mut prev_hashes = Vec::with_capacity(*num_blocks); - prev_hashes.push(SHA256_H); - for i in 0..*num_blocks - 1 { - prev_hashes.push(Sha256FillerHelper::get_block_hash( - &prev_hashes[i], - padded_input[i * SHA256_BLOCK_CELLS..(i + 1) * SHA256_BLOCK_CELLS] - .try_into() - .unwrap(), - )); - } - // Copy the read aux records and input to another place to safely fill in the trace - // matrix without overwriting the record - let mut read_aux_records = Vec::with_capacity(SHA256_NUM_READ_ROWS * num_blocks); - read_aux_records.extend_from_slice(record.read_aux); - let vm_record = record.inner.clone(); - - slice - .par_chunks_exact_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK) - .enumerate() - .for_each(|(block_idx, block_slice)| { - // Need to get rid of the accidental garbage data that might overflow the - // F's prime field. Unfortunately, there is no good way around this - // SAFETY: - // - block_slice comes from par_chunks_exact_mut with exact size guarantee - // - Length is SHA256_ROWS_PER_BLOCK * SHA256VM_WIDTH * size_of::() bytes - // - Zeroing entire blocks prevents using garbage data - // - The subsequent trace filling will overwrite with valid values - unsafe { - std::ptr::write_bytes( - block_slice.as_mut_ptr() as *mut u8, - 0, - SHA256_ROWS_PER_BLOCK * SHA256VM_WIDTH * size_of::(), - ); - } - self.fill_block_trace::( - block_slice, - &vm_record, - &read_aux_records[block_idx * SHA256_NUM_READ_ROWS - ..(block_idx + 1) * SHA256_NUM_READ_ROWS], - &input[block_idx * SHA256_BLOCK_CELLS - ..(block_idx + 1) * SHA256_BLOCK_CELLS], - &padded_input[block_idx * SHA256_BLOCK_CELLS - ..(block_idx + 1) * SHA256_BLOCK_CELLS], - block_idx == *num_blocks - 1, - *global_block_offset + block_idx, - block_idx, - prev_hashes[block_idx], - mem_helper, - ); - }); - }, - ); - - // Do a second pass over the trace to fill in the missing values - // Note, we need to skip the 
very first row - trace_matrix.values[SHA256VM_WIDTH..] - .par_chunks_mut(SHA256VM_WIDTH * SHA256_ROWS_PER_BLOCK) - .take(rows_used / SHA256_ROWS_PER_BLOCK) - .for_each(|chunk| { - self.inner - .generate_missing_cells(chunk, SHA256VM_WIDTH, SHA256VM_CONTROL_WIDTH); - }); - } -} - -impl Sha256VmFiller { - #[allow(clippy::too_many_arguments)] - fn fill_block_trace( - &self, - block_slice: &mut [F], - record: &Sha256VmRecordHeader, - read_aux_records: &[MemoryReadAuxRecord], - input: &[u8], - padded_input: &[u8], - is_last_block: bool, - global_block_idx: usize, - local_block_idx: usize, - prev_hash: [u32; 8], - mem_helper: &MemoryAuxColsFactory, - ) { - debug_assert_eq!(input.len(), SHA256_BLOCK_CELLS); - debug_assert_eq!(padded_input.len(), SHA256_BLOCK_CELLS); - debug_assert_eq!(read_aux_records.len(), SHA256_NUM_READ_ROWS); - - let padded_input = array::from_fn(|i| { - u32::from_be_bytes(padded_input[i * 4..(i + 1) * 4].try_into().unwrap()) - }); - - let block_start_timestamp = record.timestamp - + (SHA256_REGISTER_READS + SHA256_NUM_READ_ROWS * local_block_idx) as u32; - - let read_cells = (SHA256_BLOCK_CELLS * local_block_idx) as u32; - let block_start_read_ptr = record.src_ptr + read_cells; - - let message_left = if record.len <= read_cells { - 0 - } else { - (record.len - read_cells) as usize - }; - - // -1 means that padding occurred before the start of the block - // 18 means that no padding occurred on this block - let first_padding_row = if record.len < read_cells { - -1 - } else if message_left < SHA256_BLOCK_CELLS { - (message_left / SHA256_READ_SIZE) as i32 - } else { - 18 - }; - - // Fill in the VM columns first because the inner `carry_or_buffer` needs to be filled in - block_slice - .par_chunks_exact_mut(SHA256VM_WIDTH) - .enumerate() - .for_each(|(row_idx, row_slice)| { - // Handle round rows and digest row separately - if row_idx == SHA256_ROWS_PER_BLOCK - 1 { - // This is a digest row - let digest_cols: &mut Sha256VmDigestCols = - 
row_slice[..SHA256VM_DIGEST_WIDTH].borrow_mut(); - digest_cols.from_state.timestamp = F::from_canonical_u32(record.timestamp); - digest_cols.from_state.pc = F::from_canonical_u32(record.from_pc); - digest_cols.rd_ptr = F::from_canonical_u32(record.rd_ptr); - digest_cols.rs1_ptr = F::from_canonical_u32(record.rs1_ptr); - digest_cols.rs2_ptr = F::from_canonical_u32(record.rs2_ptr); - digest_cols.dst_ptr = record.dst_ptr.to_le_bytes().map(F::from_canonical_u8); - digest_cols.src_ptr = record.src_ptr.to_le_bytes().map(F::from_canonical_u8); - digest_cols.len_data = record.len.to_le_bytes().map(F::from_canonical_u8); - if is_last_block { - digest_cols - .register_reads_aux - .iter_mut() - .zip(record.register_reads_aux.iter()) - .enumerate() - .for_each(|(idx, (cols_read, record_read))| { - mem_helper.fill( - record_read.prev_timestamp, - record.timestamp + idx as u32, - cols_read.as_mut(), - ); - }); - digest_cols - .writes_aux - .set_prev_data(record.write_aux.prev_data.map(F::from_canonical_u8)); - // In the last block we do `SHA256_NUM_READ_ROWS` reads and then write the - // result thus the timestamp of the write is - // `block_start_timestamp + SHA256_NUM_READ_ROWS` - mem_helper.fill( - record.write_aux.prev_timestamp, - block_start_timestamp + SHA256_NUM_READ_ROWS as u32, - digest_cols.writes_aux.as_mut(), - ); - // Need to range check the destination and source pointers - let msl_rshift: u32 = - ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32; - let msl_lshift: u32 = (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - - self.pointer_max_bits) - as u32; - self.bitwise_lookup_chip.request_range( - (record.dst_ptr >> msl_rshift) << msl_lshift, - (record.src_ptr >> msl_rshift) << msl_lshift, - ); - } else { - // Filling in zeros to make sure the accidental garbage data doesn't - // overflow the prime - digest_cols.register_reads_aux.iter_mut().for_each(|aux| { - mem_helper.fill_zero(aux.as_mut()); - }); - digest_cols - .writes_aux - .set_prev_data([F::ZERO; 
SHA256_WRITE_SIZE]); - mem_helper.fill_zero(digest_cols.writes_aux.as_mut()); - } - digest_cols.inner.flags.is_last_block = F::from_bool(is_last_block); - digest_cols.inner.flags.is_digest_row = F::from_bool(true); - } else { - // This is a round row - let round_cols: &mut Sha256VmRoundCols = - row_slice[..SHA256VM_ROUND_WIDTH].borrow_mut(); - // Take care of the first 4 round rows (aka read rows) - if row_idx < SHA256_NUM_READ_ROWS { - round_cols - .inner - .message_schedule - .carry_or_buffer - .as_flattened_mut() - .iter_mut() - .zip( - input[row_idx * SHA256_READ_SIZE..(row_idx + 1) * SHA256_READ_SIZE] - .iter(), - ) - .for_each(|(cell, data)| { - *cell = F::from_canonical_u8(*data); - }); - mem_helper.fill( - read_aux_records[row_idx].prev_timestamp, - block_start_timestamp + row_idx as u32, - round_cols.read_aux.as_mut(), - ); - } else { - mem_helper.fill_zero(round_cols.read_aux.as_mut()); - } - } - // Fill in the control cols, doesn't matter if it is a round or digest row - let control_cols: &mut Sha256VmControlCols = - row_slice[..SHA256VM_CONTROL_WIDTH].borrow_mut(); - control_cols.len = F::from_canonical_u32(record.len); - // Only the first `SHA256_NUM_READ_ROWS` rows increment the timestamp and read ptr - control_cols.cur_timestamp = F::from_canonical_u32( - block_start_timestamp + min(row_idx, SHA256_NUM_READ_ROWS) as u32, - ); - control_cols.read_ptr = F::from_canonical_u32( - block_start_read_ptr - + (SHA256_READ_SIZE * min(row_idx, SHA256_NUM_READ_ROWS)) as u32, - ); - - // Fill in the padding flags - if row_idx < SHA256_NUM_READ_ROWS { - #[allow(clippy::comparison_chain)] - if (row_idx as i32) < first_padding_row { - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - PaddingFlags::NotPadding as usize, - ) - .map(F::from_canonical_u32); - } else if row_idx as i32 == first_padding_row { - let len = message_left - row_idx * SHA256_READ_SIZE; - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - if row_idx == 3 && 
is_last_block { - PaddingFlags::FirstPadding0_LastRow - } else { - PaddingFlags::FirstPadding0 - } as usize - + len, - ) - .map(F::from_canonical_u32); - } else { - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - if row_idx == 3 && is_last_block { - PaddingFlags::EntirePaddingLastRow - } else { - PaddingFlags::EntirePadding - } as usize, - ) - .map(F::from_canonical_u32); - } - } else { - control_cols.pad_flags = get_flag_pt_array( - &self.padding_encoder, - PaddingFlags::NotConsidered as usize, - ) - .map(F::from_canonical_u32); - } - if is_last_block && row_idx == SHA256_ROWS_PER_BLOCK - 1 { - // If last digest row, then we set padding_occurred = 0 - control_cols.padding_occurred = F::ZERO; - } else { - control_cols.padding_occurred = - F::from_bool((row_idx as i32) >= first_padding_row); - } - }); - - // Fill in the inner trace when the `buffer_or_carry` is filled in - self.inner.generate_block_trace::( - block_slice, - SHA256VM_WIDTH, - SHA256VM_CONTROL_WIDTH, - &padded_input, - self.bitwise_lookup_chip.as_ref(), - &prev_hash, - is_last_block, - global_block_idx as u32 + 1, // global block index is 1-indexed - local_block_idx as u32, - ); - } -} diff --git a/extensions/sha256/guest/src/lib.rs b/extensions/sha256/guest/src/lib.rs deleted file mode 100644 index 8f7c072f4a..0000000000 --- a/extensions/sha256/guest/src/lib.rs +++ /dev/null @@ -1,69 +0,0 @@ -#![no_std] - -#[cfg(target_os = "zkvm")] -use openvm_platform::alloc::AlignedBuf; - -/// This is custom-0 defined in RISC-V spec document -pub const OPCODE: u8 = 0x0b; -pub const SHA256_FUNCT3: u8 = 0b100; -pub const SHA256_FUNCT7: u8 = 0x1; - -/// Native hook for sha256 -/// -/// # Safety -/// -/// The VM accepts the preimage by pointer and length, and writes the -/// 32-byte hash. -/// - `bytes` must point to an input buffer at least `len` long. -/// - `output` must point to a buffer that is at least 32-bytes long. 
-/// -/// [`sha2`]: https://nvlpubs.nist.gov/nistpubs/FIPS/NIST.FIPS.180-4.pdf -#[cfg(target_os = "zkvm")] -#[inline(always)] -#[no_mangle] -pub extern "C" fn zkvm_sha256_impl(bytes: *const u8, len: usize, output: *mut u8) { - // SAFETY: assuming safety assumptions of the inputs, we handle all cases where `bytes` or - // `output` are not aligned to 4 bytes. - // The minimum alignment required for the input and output buffers - const MIN_ALIGN: usize = 4; - // The preferred alignment for the input buffer, since the input is read in chunks of 16 bytes - const INPUT_ALIGN: usize = 16; - // The preferred alignment for the output buffer, since the output is written in chunks of 32 - // bytes - const OUTPUT_ALIGN: usize = 32; - unsafe { - if bytes as usize % MIN_ALIGN != 0 { - let aligned_buff = AlignedBuf::new(bytes, len, INPUT_ALIGN); - if output as usize % MIN_ALIGN != 0 { - let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); - __native_sha256(aligned_buff.ptr, len, aligned_out.ptr); - core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); - } else { - __native_sha256(aligned_buff.ptr, len, output); - } - } else { - if output as usize % MIN_ALIGN != 0 { - let aligned_out = AlignedBuf::uninit(32, OUTPUT_ALIGN); - __native_sha256(bytes, len, aligned_out.ptr); - core::ptr::copy_nonoverlapping(aligned_out.ptr as *const u8, output, 32); - } else { - __native_sha256(bytes, len, output); - } - }; - } -} - -/// sha256 intrinsic binding -/// -/// # Safety -/// -/// The VM accepts the preimage by pointer and length, and writes the -/// 32-byte hash. -/// - `bytes` must point to an input buffer at least `len` long. -/// - `output` must point to a buffer that is at least 32-bytes long. -/// - `bytes` and `output` must be 4-byte aligned. 
-#[cfg(target_os = "zkvm")] -#[inline(always)] -fn __native_sha256(bytes: *const u8, len: usize, output: *mut u8) { - openvm_platform::custom_insn_r!(opcode = OPCODE, funct3 = SHA256_FUNCT3, funct7 = SHA256_FUNCT7, rd = In output, rs1 = In bytes, rs2 = In len); -} diff --git a/extensions/sha256/transpiler/src/lib.rs b/extensions/sha256/transpiler/src/lib.rs deleted file mode 100644 index 6b13efe055..0000000000 --- a/extensions/sha256/transpiler/src/lib.rs +++ /dev/null @@ -1,46 +0,0 @@ -use openvm_instructions::{riscv::RV32_MEMORY_AS, LocalOpcode}; -use openvm_instructions_derive::LocalOpcode; -use openvm_sha256_guest::{OPCODE, SHA256_FUNCT3, SHA256_FUNCT7}; -use openvm_stark_backend::p3_field::PrimeField32; -use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput}; -use rrs_lib::instruction_formats::RType; -use strum::{EnumCount, EnumIter, FromRepr}; - -#[derive( - Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode, -)] -#[opcode_offset = 0x320] -#[repr(usize)] -pub enum Rv32Sha256Opcode { - SHA256, -} - -#[derive(Default)] -pub struct Sha256TranspilerExtension; - -impl TranspilerExtension for Sha256TranspilerExtension { - fn process_custom(&self, instruction_stream: &[u32]) -> Option> { - if instruction_stream.is_empty() { - return None; - } - let instruction_u32 = instruction_stream[0]; - let opcode = (instruction_u32 & 0x7f) as u8; - let funct3 = ((instruction_u32 >> 12) & 0b111) as u8; - - if (opcode, funct3) != (OPCODE, SHA256_FUNCT3) { - return None; - } - let dec_insn = RType::new(instruction_u32); - - if dec_insn.funct7 != SHA256_FUNCT7 as u32 { - return None; - } - let instruction = from_r_type( - Rv32Sha256Opcode::SHA256.global_opcode().as_usize(), - RV32_MEMORY_AS as usize, - &dec_insn, - true, - ); - Some(TranspilerOutput::one_to_one(instruction)) - } -} diff --git a/guest-libs/k256/Cargo.toml b/guest-libs/k256/Cargo.toml index 6e42d7e52c..81ea827ebe 100644 --- 
a/guest-libs/k256/Cargo.toml +++ b/guest-libs/k256/Cargo.toml @@ -35,8 +35,8 @@ openvm-transpiler.workspace = true openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true openvm-toolchain-tests.workspace = true @@ -71,14 +71,14 @@ pkcs8 = ["ecdsa-core/pkcs8", "elliptic-curve/pkcs8"] precomputed-tables = ["arithmetic", "once_cell"] schnorr = ["arithmetic", "signature"] serde = ["ecdsa-core/serde", "elliptic-curve/serde"] -sha256 = [] +sha2 = [] test-vectors = [] # Internal feature for testing only. cuda = [ "openvm-circuit/cuda", "openvm-ecc-circuit/cuda", - "openvm-sha256-circuit/cuda", + "openvm-sha2-circuit/cuda", ] tco = ["openvm-circuit/tco"] diff --git a/guest-libs/k256/tests/lib.rs b/guest-libs/k256/tests/lib.rs index 59eb42dde3..7573d7a49d 100644 --- a/guest-libs/k256/tests/lib.rs +++ b/guest-libs/k256/tests/lib.rs @@ -13,7 +13,7 @@ mod guest_tests { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -99,7 +99,7 @@ mod guest_tests { CurveConfig, Rv32WeierstrassBuilder, Rv32WeierstrassConfig, Rv32WeierstrassConfigExecutor, }; - use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256ProverExt}; + use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2ProverExt}; use serde::{Deserialize, Serialize}; #[cfg(feature = "cuda")] use { @@ -128,14 +128,14 @@ mod guest_tests { #[config(generics = true)] pub weierstrass: Rv32WeierstrassConfig, 
#[extension] - pub sha256: Sha256, + pub sha2: Sha2, } impl EcdsaConfig { pub fn new(curves: Vec) -> Self { Self { weierstrass: Rv32WeierstrassConfig::new(curves), - sha256: Default::default(), + sha2: Default::default(), } } } @@ -179,8 +179,8 @@ mod guest_tests { )?; let inventory = &mut chip_complex.inventory; VmProverExtension::::extend_prover( - &Sha256ProverExt, - &config.sha256, + &Sha2ProverExt, + &config.sha2, inventory, )?; Ok(chip_complex) @@ -214,8 +214,8 @@ mod guest_tests { )?; let inventory = &mut chip_complex.inventory; VmProverExtension::::extend_prover( - &Sha256ProverExt, - &config.sha256, + &Sha2ProverExt, + &config.sha2, inventory, )?; Ok(chip_complex) @@ -237,7 +237,7 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) .with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; air_test(EcdsaBuilder, config, openvm_exe); Ok(()) diff --git a/guest-libs/p256/Cargo.toml b/guest-libs/p256/Cargo.toml index 852fa7af95..f928a8cc52 100644 --- a/guest-libs/p256/Cargo.toml +++ b/guest-libs/p256/Cargo.toml @@ -32,8 +32,8 @@ openvm-transpiler.workspace = true openvm-algebra-transpiler.workspace = true openvm-ecc-transpiler.workspace = true openvm-ecc-circuit.workspace = true -openvm-sha256-circuit.workspace = true -openvm-sha256-transpiler.workspace = true +openvm-sha2-circuit.workspace = true +openvm-sha2-transpiler.workspace = true openvm-rv32im-transpiler.workspace = true openvm-toolchain-tests.workspace = true @@ -60,7 +60,7 @@ jwk = ["elliptic-curve/jwk"] pem = ["elliptic-curve/pem", "ecdsa-core/pem", "pkcs8"] pkcs8 = ["ecdsa-core?/pkcs8", "elliptic-curve/pkcs8"] serde = ["ecdsa-core?/serde", "elliptic-curve/serde"] -sha256 = [] +sha2 = [] test-vectors = [] voprf = ["elliptic-curve/voprf"] @@ -68,7 +68,7 @@ voprf = ["elliptic-curve/voprf"] cuda = [ "openvm-circuit/cuda", "openvm-ecc-circuit/cuda", - 
"openvm-sha256-circuit/cuda", + "openvm-sha2-circuit/cuda", ] tco = ["openvm-circuit/tco"] diff --git a/guest-libs/p256/tests/lib.rs b/guest-libs/p256/tests/lib.rs index 9eaf2b2c74..19f8eb464e 100644 --- a/guest-libs/p256/tests/lib.rs +++ b/guest-libs/p256/tests/lib.rs @@ -13,7 +13,7 @@ mod guest_tests { use openvm_rv32im_transpiler::{ Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension, }; - use openvm_sha256_transpiler::Sha256TranspilerExtension; + use openvm_sha2_transpiler::Sha2TranspilerExtension; use openvm_stark_sdk::p3_baby_bear::BabyBear; use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir}; use openvm_transpiler::{transpiler::Transpiler, FromElf}; @@ -99,7 +99,7 @@ mod guest_tests { CurveConfig, Rv32WeierstrassBuilder, Rv32WeierstrassConfig, Rv32WeierstrassConfigExecutor, }; - use openvm_sha256_circuit::{Sha256, Sha256Executor, Sha256ProverExt}; + use openvm_sha2_circuit::{Sha2, Sha2Executor, Sha2ProverExt}; use serde::{Deserialize, Serialize}; #[cfg(feature = "cuda")] use { @@ -128,14 +128,14 @@ mod guest_tests { #[config(generics = true)] pub weierstrass: Rv32WeierstrassConfig, #[extension] - pub sha256: Sha256, + pub sha2: Sha2, } impl EcdsaConfig { pub fn new(curves: Vec) -> Self { Self { weierstrass: Rv32WeierstrassConfig::new(curves), - sha256: Default::default(), + sha2: Default::default(), } } } @@ -179,8 +179,8 @@ mod guest_tests { )?; let inventory = &mut chip_complex.inventory; VmProverExtension::::extend_prover( - &Sha256ProverExt, - &config.sha256, + &Sha2ProverExt, + &config.sha2, inventory, )?; Ok(chip_complex) @@ -214,8 +214,8 @@ mod guest_tests { )?; let inventory = &mut chip_complex.inventory; VmProverExtension::::extend_prover( - &Sha256ProverExt, - &config.sha256, + &Sha2ProverExt, + &config.sha2, inventory, )?; Ok(chip_complex) @@ -237,7 +237,7 @@ mod guest_tests { .with_extension(Rv32IoTranspilerExtension) .with_extension(EccTranspilerExtension) 
.with_extension(ModularTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; air_test(EcdsaBuilder, config, openvm_exe); Ok(()) diff --git a/guest-libs/sha2/Cargo.toml b/guest-libs/sha2/Cargo.toml index 573930affb..b00e9f1831 100644 --- a/guest-libs/sha2/Cargo.toml +++ b/guest-libs/sha2/Cargo.toml @@ -10,22 +10,21 @@ repository.workspace = true license.workspace = true [dependencies] -openvm-sha256-guest = { workspace = true } +openvm-sha2-guest = { workspace = true } +openvm-sha2-air = { workspace = true } +sha2 = { workspace = true, default-features = false } [dev-dependencies] openvm-instructions = { workspace = true } openvm-stark-sdk = { workspace = true } openvm-circuit = { workspace = true, features = ["test-utils", "parallel"] } openvm-transpiler = { workspace = true } -openvm-sha256-transpiler = { workspace = true } -openvm-sha256-circuit = { workspace = true } +openvm-sha2-transpiler = { workspace = true } +openvm-sha2-circuit = { workspace = true } openvm-rv32im-transpiler = { workspace = true } openvm-toolchain-tests = { workspace = true } eyre = { workspace = true } -[target.'cfg(not(target_os = "zkvm"))'.dependencies] -sha2 = { workspace = true } - [features] # Internal feature for testing only. -cuda = ["openvm-sha256-circuit/cuda"] \ No newline at end of file +cuda = ["openvm-sha2-circuit/cuda"] \ No newline at end of file diff --git a/guest-libs/sha2/src/host_impl.rs b/guest-libs/sha2/src/host_impl.rs new file mode 100644 index 0000000000..da74208f91 --- /dev/null +++ b/guest-libs/sha2/src/host_impl.rs @@ -0,0 +1,3 @@ +// On a host execution environment, the zkvm impl's input buffering is not necessary, and we can +// use the sha2 crate directly. 
+pub use sha2::{Sha256, Sha384, Sha512}; diff --git a/guest-libs/sha2/src/lib.rs b/guest-libs/sha2/src/lib.rs index 43d90ba822..2aaeabbfe0 100644 --- a/guest-libs/sha2/src/lib.rs +++ b/guest-libs/sha2/src/lib.rs @@ -1,28 +1,13 @@ #![no_std] -/// The sha256 cryptographic hash function. -#[inline(always)] -pub fn sha256(input: &[u8]) -> [u8; 32] { - let mut output = [0u8; 32]; - set_sha256(input, &mut output); - output -} +pub use sha2::Digest; -/// Sets `output` to the sha256 hash of `input`. -pub fn set_sha256(input: &[u8], output: &mut [u8; 32]) { - #[cfg(not(target_os = "zkvm"))] - { - use sha2::{Digest, Sha256}; - let mut hasher = Sha256::new(); - hasher.update(input); - output.copy_from_slice(hasher.finalize().as_ref()); - } - #[cfg(target_os = "zkvm")] - { - openvm_sha256_guest::zkvm_sha256_impl( - input.as_ptr(), - input.len(), - output.as_mut_ptr() as *mut u8, - ); - } -} +#[cfg(not(target_os = "zkvm"))] +mod host_impl; +#[cfg(target_os = "zkvm")] +mod zkvm_impl; + +#[cfg(not(target_os = "zkvm"))] +pub use host_impl::*; +#[cfg(target_os = "zkvm")] +pub use zkvm_impl::*; diff --git a/guest-libs/sha2/src/zkvm_impl.rs b/guest-libs/sha2/src/zkvm_impl.rs new file mode 100644 index 0000000000..f11187716f --- /dev/null +++ b/guest-libs/sha2/src/zkvm_impl.rs @@ -0,0 +1,267 @@ +use core::cmp::min; + +use sha2::digest::{ + consts::{U32, U64}, + FixedOutput, HashMarker, Output, OutputSizeUser, Update, +}; + +// TODO: the three implementations can be merged into one using a macro + +const SHA256_STATE_BYTES: usize = 32; +const SHA256_BLOCK_BYTES: usize = 64; +const SHA256_DIGEST_BYTES: usize = 32; + +// Initial state for SHA-256 in big-endian bytes +const SHA256_H: [u8; SHA256_STATE_BYTES] = [ + 106, 9, 230, 103, 187, 103, 174, 133, 60, 110, 243, 114, 165, 79, 245, 58, 81, 14, 82, 127, + 155, 5, 104, 140, 31, 131, 217, 171, 91, 224, 205, 25, +]; + +#[derive(Debug, Clone, Copy)] +pub struct Sha256 { + // the current hasher state, in big-endian + state: [u8; 
SHA256_STATE_BYTES],
+    // the next block of input
+    buffer: [u8; SHA256_BLOCK_BYTES],
+    // idx of next byte to write to buffer
+    idx: usize,
+    // accumulated length of the input data, in bytes
+    len: usize,
+}
+
+impl Default for Sha256 {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl Sha256 {
+    pub fn new() -> Self {
+        Self {
+            state: SHA256_H,
+            buffer: [0; SHA256_BLOCK_BYTES],
+            idx: 0,
+            len: 0,
+        }
+    }
+
+    pub fn update(&mut self, mut input: &[u8]) {
+        self.len += input.len();
+        while !input.is_empty() {
+            let to_copy = min(input.len(), SHA256_BLOCK_BYTES - self.idx);
+            self.buffer[self.idx..self.idx + to_copy].copy_from_slice(&input[..to_copy]);
+            self.idx += to_copy;
+            if self.idx == SHA256_BLOCK_BYTES {
+                self.idx = 0;
+                self.compress();
+            }
+            input = &input[to_copy..];
+        }
+    }
+
+    pub fn finalize(mut self) -> [u8; SHA256_DIGEST_BYTES] {
+        self.update(&[0x80]);
+        // If the 0x80 marker left no room for the 8-byte length field in this
+        // block (len % 64 in 56..=62), zero-fill the remainder and compress an
+        // extra padding block first, per FIPS 180-4 Section 5.1.1.
+        if self.idx > SHA256_BLOCK_BYTES - 8 {
+            self.buffer[self.idx..].fill(0);
+            self.idx = 0;
+            self.compress();
+        }
+        while self.idx < SHA256_BLOCK_BYTES - 8 {
+            self.buffer[self.idx] = 0;
+            self.idx += 1;
+        }
+        // The length field is the message length in *bits*, big-endian.
+        self.buffer[SHA256_BLOCK_BYTES - 8..SHA256_BLOCK_BYTES]
+            .copy_from_slice(&((self.len as u64) * 8).to_be_bytes());
+        self.compress();
+        self.state
+    }
+
+    fn compress(&mut self) {
+        openvm_sha2_guest::zkvm_sha256_impl(
+            self.state.as_ptr(),
+            self.buffer.as_ptr(),
+            self.state.as_mut_ptr() as *mut u8,
+        );
+    }
+}
+
+// We will implement FixedOutput, Default, Update, and HashMarker for Sha256 so that
+// the blanket implementation of sha2::Digest is available.
+// See: https://docs.rs/sha2/latest/sha2/trait.Digest.html#impl-Digest-for-D +impl Update for Sha256 { + fn update(&mut self, input: &[u8]) { + self.update(input); + } +} + +// OutputSizeUser is required for FixedOutput +// See: https://docs.rs/digest/0.10.7/digest/trait.FixedOutput.html +impl OutputSizeUser for Sha256 { + type OutputSize = U32; +} + +impl FixedOutput for Sha256 { + fn finalize_into(self, out: &mut Output) { + out.copy_from_slice(&self.finalize()); + } +} + +impl HashMarker for Sha256 {} + +const SHA512_STATE_BYTES: usize = 64; +const SHA512_BLOCK_BYTES: usize = 128; +const SHA512_DIGEST_BYTES: usize = 64; + +// Initial state for SHA-512 in big-endian bytes +const SHA512_H: [u8; SHA512_STATE_BYTES] = [ + 106, 9, 230, 103, 243, 188, 201, 8, 187, 103, 174, 133, 132, 202, 167, 59, 60, 110, 243, 114, + 254, 148, 248, 43, 165, 79, 245, 58, 95, 29, 54, 241, 81, 14, 82, 127, 173, 230, 130, 209, 155, + 5, 104, 140, 43, 62, 108, 31, 31, 131, 217, 171, 251, 65, 189, 107, 91, 224, 205, 25, 19, 126, + 33, 121, +]; + +#[derive(Debug, Clone, Copy)] +pub struct Sha512 { + // the current hasher state + state: [u8; SHA512_STATE_BYTES], + // the next block of input + buffer: [u8; SHA512_BLOCK_BYTES], + // idx of next byte to write to buffer + idx: usize, + // accumulated length of the input data + len: usize, +} + +impl Default for Sha512 { + fn default() -> Self { + Self::new() + } +} + +impl Sha512 { + pub fn new() -> Self { + Self { + state: SHA512_H, + buffer: [0; SHA512_BLOCK_BYTES], + idx: 0, + len: 0, + } + } + + pub fn update(&mut self, mut input: &[u8]) { + self.len += input.len(); + while !input.is_empty() { + let to_copy = min(input.len(), SHA512_BLOCK_BYTES - self.idx); + self.buffer[self.idx..self.idx + to_copy].copy_from_slice(&input[..to_copy]); + self.idx += to_copy; + if self.idx == SHA512_BLOCK_BYTES { + self.idx = 0; + self.compress(); + } + input = &input[to_copy..]; + } + } + + pub fn finalize(mut self) -> [u8; SHA512_DIGEST_BYTES] { + 
self.update(&[0x80]);
+        // If the 0x80 marker left no room for the 16-byte length field in this
+        // block (len % 128 in 112..=126), zero-fill the remainder and compress
+        // an extra padding block first, per FIPS 180-4 Section 5.1.2.
+        if self.idx > SHA512_BLOCK_BYTES - 16 {
+            self.buffer[self.idx..].fill(0);
+            self.idx = 0;
+            self.compress();
+        }
+        // SHA-512 uses a 128-bit length field, so pad up to BLOCK - 16
+        // (not BLOCK - 8 as in SHA-256).
+        while self.idx < SHA512_BLOCK_BYTES - 16 {
+            self.buffer[self.idx] = 0;
+            self.idx += 1;
+        }
+        // The length field is the message length in *bits*, big-endian.
+        self.buffer[SHA512_BLOCK_BYTES - 16..SHA512_BLOCK_BYTES]
+            .copy_from_slice(&((self.len as u128) * 8).to_be_bytes());
+        self.compress();
+        self.state
+    }
+
+    fn compress(&mut self) {
+        openvm_sha2_guest::zkvm_sha512_impl(
+            self.state.as_ptr(),
+            self.buffer.as_ptr(),
+            self.state.as_mut_ptr() as *mut u8,
+        );
+    }
+}
+
+// We will implement FixedOutput, Default, Update, and HashMarker for Sha512 so that
+// the blanket implementation of sha2::Digest is available.
+// See: https://docs.rs/sha2/latest/sha2/trait.Digest.html#impl-Digest-for-D
+impl Update for Sha512 {
+    fn update(&mut self, input: &[u8]) {
+        self.update(input);
+    }
+}
+
+// OutputSizeUser is required for FixedOutput
+// See: https://docs.rs/digest/0.10.7/digest/trait.FixedOutput.html
+impl OutputSizeUser for Sha512 {
+    type OutputSize = U64;
+}
+
+impl FixedOutput for Sha512 {
+    fn finalize_into(self, out: &mut Output<Self>) {
+        out.copy_from_slice(&self.finalize());
+    }
+}
+
+impl HashMarker for Sha512 {}
+
+const SHA384_STATE_BYTES: usize = 64;
+const SHA384_BLOCK_BYTES: usize = 128;
+const SHA384_DIGEST_BYTES: usize = 48;
+
+const SHA384_H: [u8; SHA384_STATE_BYTES] = [
+    203, 187, 157, 93, 193, 5, 158, 216, 98, 154, 41, 42, 54, 124, 213, 7, 145, 89, 1, 90, 48, 112,
+    221, 23, 21, 47, 236, 216, 247, 14, 89, 57, 103, 51, 38, 103, 255, 192, 11, 49, 142, 180, 74,
+    135, 104, 88, 21, 17, 219, 12, 46, 13, 100, 249, 143, 167, 71, 181, 72, 29, 190, 250, 79, 164,
+];
+
+#[derive(Debug, Clone, Copy)]
+pub struct Sha384 {
+    inner: Sha512,
+}
+
+impl Default for Sha384 {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl Sha384 {
+    pub fn new() -> Self {
+        let mut inner = Sha512::new();
+        inner.state = SHA384_H;
+        Self { inner }
+    }
+
+    pub fn update(&mut self, input: &[u8]) {
+        self.inner.update(input);
+    }
+
+    pub fn finalize(self) -> [u8; SHA384_DIGEST_BYTES] {
+        let digest = self.inner.finalize();
digest[..SHA384_DIGEST_BYTES].try_into().unwrap()
+    }
+
+    fn compress(&mut self) {
+        self.inner.compress();
+    }
+}
+
+// We will implement FixedOutput, Default, Update, and HashMarker for Sha384 so that
+// the blanket implementation of sha2::Digest is available.
+// See: https://docs.rs/sha2/latest/sha2/trait.Digest.html#impl-Digest-for-D
+impl Update for Sha384 {
+    fn update(&mut self, input: &[u8]) {
+        self.update(input);
+    }
+}
+
+// OutputSizeUser is required for FixedOutput
+// See: https://docs.rs/digest/0.10.7/digest/trait.FixedOutput.html
+// NOTE(review): SHA-384 produces a 48-byte digest, so OutputSize must be U48;
+// with U64 the `copy_from_slice` in `finalize_into` below panics on a
+// length mismatch (48 bytes into a 64-byte Output).
+impl OutputSizeUser for Sha384 {
+    type OutputSize = sha2::digest::consts::U48;
+}
+
+impl FixedOutput for Sha384 {
+    fn finalize_into(self, out: &mut Output<Self>) {
+        out.copy_from_slice(&self.finalize());
+    }
+}
+
+impl HashMarker for Sha384 {}
diff --git a/guest-libs/sha2/tests/lib.rs b/guest-libs/sha2/tests/lib.rs
index adfae8e764..26319f7e4d 100644
--- a/guest-libs/sha2/tests/lib.rs
+++ b/guest-libs/sha2/tests/lib.rs
@@ -6,8 +6,8 @@ mod tests {
     use openvm_rv32im_transpiler::{
         Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension,
     };
-    use openvm_sha256_circuit::{Sha256Rv32Builder, Sha256Rv32Config};
-    use openvm_sha256_transpiler::Sha256TranspilerExtension;
+    use openvm_sha2_circuit::{Sha2Rv32Builder, Sha2Rv32Config};
+    use openvm_sha2_transpiler::Sha2TranspilerExtension;
     use openvm_stark_sdk::p3_baby_bear::BabyBear;
     use openvm_toolchain_tests::{build_example_program_at_path, get_programs_dir};
     use openvm_transpiler::{transpiler::Transpiler, FromElf};
@@ -15,19 +15,19 @@ mod tests {
     type F = BabyBear;
 
     #[test]
-    fn test_sha256() -> Result<()> {
-        let config = Sha256Rv32Config::default();
+    fn test_sha2() -> Result<()> {
+        let config = Sha2Rv32Config::default();
         let elf =
-            build_example_program_at_path(get_programs_dir!("tests/programs"), "sha", &config)?;
+            build_example_program_at_path(get_programs_dir!("tests/programs"), "sha2", &config)?;
         let openvm_exe = VmExe::from_elf(
             elf,
             Transpiler::<F>::default()
.with_extension(Rv32ITranspilerExtension) .with_extension(Rv32MTranspilerExtension) .with_extension(Rv32IoTranspilerExtension) - .with_extension(Sha256TranspilerExtension), + .with_extension(Sha2TranspilerExtension), )?; - air_test(Sha256Rv32Builder, config, openvm_exe); + air_test(Sha2Rv32Builder, config, openvm_exe); Ok(()) } } diff --git a/guest-libs/sha2/tests/programs/Cargo.toml b/guest-libs/sha2/tests/programs/Cargo.toml index df13f8dfc7..c197564ec0 100644 --- a/guest-libs/sha2/tests/programs/Cargo.toml +++ b/guest-libs/sha2/tests/programs/Cargo.toml @@ -8,12 +8,12 @@ edition = "2021" openvm = { path = "../../../../crates/toolchain/openvm" } openvm-platform = { path = "../../../../crates/toolchain/platform" } openvm-sha2 = { path = "../../" } - hex = { version = "0.4.3", default-features = false, features = ["alloc"] } serde = { version = "1.0", default-features = false, features = [ "alloc", "derive", ] } +hex-literal = { version = "1.0.0" } [features] default = [] diff --git a/guest-libs/sha2/tests/programs/examples/sha.rs b/guest-libs/sha2/tests/programs/examples/sha.rs deleted file mode 100644 index ebfd50cbee..0000000000 --- a/guest-libs/sha2/tests/programs/examples/sha.rs +++ /dev/null @@ -1,29 +0,0 @@ -#![cfg_attr(not(feature = "std"), no_main)] -#![cfg_attr(not(feature = "std"), no_std)] - -extern crate alloc; - -use alloc::vec::Vec; -use core::hint::black_box; - -use hex::FromHex; -use openvm_sha2::sha256; - -openvm::entry!(main); - -pub fn main() { - let test_vectors = [ - ("", "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"), - ("98c1c0bdb7d5fea9a88859f06c6c439f", "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05"), - ("5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7"), - 
("9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23") - ]; - for (input, expected_output) in test_vectors.iter() { - let input = Vec::from_hex(input).unwrap(); - let expected_output = Vec::from_hex(expected_output).unwrap(); - let output = sha256(&black_box(input)); - if output != *expected_output { - panic!(); - } - } -} diff --git a/guest-libs/sha2/tests/programs/examples/sha2.rs b/guest-libs/sha2/tests/programs/examples/sha2.rs new file mode 100644 index 0000000000..91eb652698 --- /dev/null +++ b/guest-libs/sha2/tests/programs/examples/sha2.rs @@ -0,0 +1,91 @@ +#![cfg_attr(not(feature = "std"), no_main)] +#![cfg_attr(not(feature = "std"), no_std)] + +extern crate alloc; + +use alloc::vec::Vec; +use core::hint::black_box; + +use hex::FromHex; +use openvm_sha2::{Sha256, Sha384, Sha512}; + +openvm::entry!(main); + +struct ShaTestVector { + input: &'static str, + expected_output_sha256: &'static str, + expected_output_sha512: &'static str, + 
expected_output_sha384: &'static str, +} + +pub fn main() { + let test_vectors = [ + ShaTestVector { + input: "", + expected_output_sha256: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expected_output_sha512: "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce47d0d13c5d85f2b0ff8318d2877eec2f63b931bd47417a81a538327af927da3e", + expected_output_sha384: "38b060a751ac96384cd9327eb1b1e36a21fdb71114be07434c0cc7bf63f6e1da274edebfe76f65fbd51ad2f14898b95b", + }, + ShaTestVector { + input: "98c1c0bdb7d5fea9a88859f06c6c439f", + expected_output_sha256: "b6b2c9c9b6f30e5c66c977f1bd7ad97071bee739524aecf793384890619f2b05", + expected_output_sha512: "eb576959c531f116842c0cc915a29c8f71d7a285c894c349b83469002ef093d51f9f14ce4248488bff143025e47ed27c12badb9cd43779cb147408eea062d583", + expected_output_sha384: "63e3061aab01f335ea3a4e617b9d14af9b63a5240229164ee962f6d5335ff25f0f0bf8e46723e83c41b9d17413b6a3c7", + }, + ShaTestVector { + input: "5b58f4163e248467cc1cd3eecafe749e8e2baaf82c0f63af06df0526347d7a11327463c115210a46b6740244eddf370be89c", + expected_output_sha256: "ac0e25049870b91d78ef6807bb87fce4603c81abd3c097fba2403fd18b6ce0b7", + expected_output_sha512: "a20d5fb14814d045a7d2861e80d2b688f1cd1daaba69e6bb1cc5233f514141ea4623b3373af702e78e3ec5dc8c1b716a37a9a2f5fbc9493b9df7043f5e99a8da", + expected_output_sha384: "eac4b72b0540486bc088834860873338e31e9e4062532bf509191ef63b9298c67db5654a28fe6f07e4cc6ff466d1be24", + }, + ShaTestVector { + input: 
"9ad198539e3160194f38ac076a782bd5210a007560d1fce9ef78f8a4a5e4d78c6b96c250cff3520009036e9c6087d5dab587394edda862862013de49a12072485a6c01165ec0f28ffddf1873fbd53e47fcd02fb6a5ccc9622d5588a92429c663ce298cb71b50022fc2ec4ba9f5bbd250974e1a607b165fee16e8f3f2be20d7348b91a2f518ce928491900d56d9f86970611580350cee08daea7717fe28a73b8dcfdea22a65ed9f5a09198de38e4e4f2cc05b0ba3dd787a5363ab6c9f39dcb66c1a29209b1d6b1152769395df8150b4316658ea6ab19af94903d643fcb0ae4d598035ebe73c8b1b687df1ab16504f633c929569c6d0e5fae6eea43838fbc8ce2c2b43161d0addc8ccf945a9c4e06294e56a67df0000f561f61b630b1983ba403e775aaeefa8d339f669d1e09ead7eae979383eda983321e1743e5404b4b328da656de79ff52d179833a6bd5129f49432d74d001996c37c68d9ab49fcff8061d193576f396c20e1f0d9ee83a51290ba60efa9c3cb2e15b756321a7ca668cdbf63f95ec33b1c450aa100101be059dc00077245b25a6a66698dee81953ed4a606944076e2858b1420de0095a7f60b08194d6d9a997009d345c71f63a7034b976e409af8a9a040ac7113664609a7adedb76b2fadf04b0348392a1650526eb2a4d6ed5e4bbcda8aabc8488b38f4f5d9a398103536bb8250ed82a9b9825f7703c263f9e", + expected_output_sha256: "080ad71239852124fc26758982090611b9b19abf22d22db3a57f67a06e984a23", + expected_output_sha512: "8d215ee6dc26757c210db0dd00c1c6ed16cc34dbd4bb0fa10c1edb6b62d5ab16aea88c881001b173d270676daf2d6381b5eab8711fa2f5589c477c1d4b84774f", + expected_output_sha384: "904a90010d772a904a35572fdd4bdf1dd253742e47872c8a18e2255f66fa889e44781e65487a043f435daa53c496a53e", + } + ]; + + for ( + i, + ShaTestVector { + input, + expected_output_sha256, + expected_output_sha512, + expected_output_sha384, + }, + ) in test_vectors.iter().enumerate() + { + let input = Vec::from_hex(input).unwrap(); + let expected_output_sha256 = Vec::from_hex(expected_output_sha256).unwrap(); + let mut hasher = Sha256::new(); + hasher.update(black_box(&input)); + let output = hasher.finalize(); + if output != *expected_output_sha256 { + panic!( + "sha256 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, expected_output_sha256, output + ); + } + let 
expected_output_sha512 = Vec::from_hex(expected_output_sha512).unwrap(); + let mut hasher = Sha512::new(); + hasher.update(black_box(&input)); + let output = hasher.finalize(); + if output != *expected_output_sha512 { + panic!( + "sha512 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, expected_output_sha512, output + ); + } + let expected_output_sha384 = Vec::from_hex(expected_output_sha384).unwrap(); + let mut hasher = Sha384::new(); + hasher.update(black_box(&input)); + let output = hasher.finalize(); + if output != *expected_output_sha384 { + panic!( + "sha384 test {i} failed on input: {:?}.\nexpected: {:?},\ngot: {:?}", + input, expected_output_sha384, output + ); + } + } +}