diff --git a/.editorconfig b/.editorconfig index be26136a..4a853b30 100644 --- a/.editorconfig +++ b/.editorconfig @@ -19,6 +19,9 @@ quote_type = single [*.{json,yaml,yml,toml}] indent_size = 2 +[*.{md,mdx}] +indent_size = 2 + [*{toml}] end_of_line = lf indent_size = 2 diff --git a/.gitignore b/.gitignore index e61287c0..94736155 100644 --- a/.gitignore +++ b/.gitignore @@ -1,65 +1,68 @@ -# Artifacts +# General +**/.artifacts/tmp/ +**/.docker/data -**/.artifacts/data/ -**/.docker/data/ +**/.cache/ +**/.DS_Store +*.bk +*.bk.* -## Caches +*.bak +*.bak.* -**/.cache/ +*.csv +*.csv.* -# Configuration Files +*.db +*.db.* -**/config.* -**/*.config.* +*.db-*.* -**/*.env -**/*.env.* +*.error +*.error.* -### Exceptions -!**/default.config.* -!**/*.config.cjs -!**/*.config.js -!**/*.config.mjs -!**/config.py -!**/config.rs +*.gz +*.gz.* -!**/example.env.* -!**/*.env.example -!**/*.env.default +*.log +*.log.* -# Dev +*.sqlite +*.sqlite.* -### Idea +*.zip +*.zip.* -**/.idea/ +## Development Environments -### vscode - -**/.vscode/ +### IntelliJ Idea +**/.idea/ -# File Extensions +**/*.iml +**/*.iml.* -**/*.lock -**/*.lock.* +**/*.iws +**/*.iws.* -**/*-lock.* +### Visual Studio Code +**/.vscode/ -**/*.log -**/*.log.* +!.vscode/settings.* +!.vscode/launch.* -### Data Files +## Operating Systems -**/*.csv -**/*.csv.* +### Nix +!flake.lock -**/*.db -**/*.db.* +### Windows (WSL2) +**/*:Zone.Identifier -**/*.db-*.* +**/.artifacts/data/ +**/.docker/data/ -**/*.zip -**/*.zip.* +# Language Specific ## Rust **/debug/ @@ -70,17 +73,6 @@ !Cargo.lock -## Node -**/build/ -**/debug/ -**/dist/ -**/node_modules/ - -### SvelteKit -**/__sapper__/ -**/.DS_STORE/ -**/.svelte-kit/ - ## Python **/__pycache__/ **/.pytest_cache/ @@ -106,7 +98,53 @@ **/*.whl **/*.whl.* -## Operating Systems +## Node.js +**/.DS_STORE +**/.pnp +**/.pnp.* +**/.vercel +**/build +**/debug +**/dist +**/node_modules -### Windows (WSL2) -**/*:Zone.Identifier \ No newline at end of file +### TypeScript +**/*.tsbuildinfo + +### npm +!**/package-lock.json +**/npm-debug.log + +### yarn +**/.yarn/install-state.gz +**/yarn-error.log +**/yarn-debug.log + +### SvelteKit +**/__sapper__/ +**/.svelte-kit/ + +### Deno +**/deno.json +**/deno.jsonc + +### Next.js +**/.next +**/out + +**/next-env.d.ts + +### Jest +**/.swc +**/coverage + +### OpenNext (Cloudflare) +**/.open-next +**/.wrangler +**/.vercel + +### Supabase +**/supabase/temp + +### Vercel +**/.vercel \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..eeeccbf4 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "rust-analyzer.check.features": [ + "default", + "full" + ] +} \ No newline at end of file diff --git a/CODEOWNERS b/CODEOWNERS index 5feac4ca..b4dd9c78 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -1,3 +1,3 @@ -# the default owners of all files and directories within the repository -# except those that are explicitly assigned to other owners. 
-* @FL03 @scattered-systems +# *** ROOT *** +# Unless otherwise specified, all files in this repository are owned by the following user(s): +* @FL03 \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index d6d6b333..b670e9cf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "android_system_properties" version = "0.1.5" @@ -100,9 +115,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.47" +version = "1.2.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd405d82c84ff7f35739f175f67d8b9fb7687a0e84ccdc78bd3568839827cf07" +checksum = "90583009037521a116abf44494efecd645ba48b6622457080f080b85544e2215" dependencies = [ "find-msvc-tools", "shlex", @@ -180,7 +195,7 @@ checksum = "a1d728cc89cf3aee9ff92b05e62b19ee65a02b5702cff7d5a377e32c6ae29d8d" [[package]] name = "concision" -version = "0.2.9" +version = "0.3.0" dependencies = [ "anyhow", "approx", @@ -189,32 +204,27 @@ dependencies = [ "concision-derive", "concision-macros", "criterion", - "lazy_static", "ndarray", - "num", - "serde", "tracing", "tracing-subscriber", ] [[package]] name = "concision-core" -version = "0.2.9" +version = "0.3.0" dependencies = [ "anyhow", "approx", "concision-init", "concision-params", "concision-traits", - "getrandom", + "hashbrown 0.16.1", "lazy_static", "ndarray", "num-complex", "num-integer", "num-traits", "paste", - "rand 0.9.2", - "rand_distr", "rayon", "rustfft", "serde", @@ -229,7 +239,7 @@ dependencies = [ [[package]] name = "concision-data" -version = "0.2.9" +version = "0.3.0" dependencies = [ "approx", "concision-core", @@ -248,7 +258,7 @@ dependencies = [ [[package]] name = "concision-derive" -version = "0.2.9" +version = "0.3.0" dependencies = [ "proc-macro2", "quote", @@ -257,7 +267,7 @@ dependencies = [ [[package]] name = "concision-ext" -version = "0.2.9" +version = "0.3.0" dependencies = [ "anyhow", "approx", @@ -279,10 +289,10 @@ dependencies = [ [[package]] name = "concision-init" -version = "0.2.9" +version = "0.3.0" dependencies = [ "approx", - "getrandom", + "getrandom 0.3.4", "lazy_static", "ndarray", "num", @@ -302,7 +312,7 @@ dependencies = [ [[package]] name = "concision-macros" -version = "0.2.9" +version = "0.3.0" dependencies = [ "proc-macro2", "quote", @@ -311,19 +321,19 @@ dependencies = [ [[package]] name = "concision-params" -version = "0.2.9" +version = "0.3.0" dependencies = [ + "anyhow", "approx", "concision-init", "concision-traits", - "getrandom", - "lazy_static", + "getrandom 0.3.4", "ndarray", "num-complex", "num-traits", "rand 0.9.2", "rand_distr", - "rayon", + "rayon-core", "serde", "serde_derive", "serde_json", @@ -333,10 +343,11 @@ dependencies = [ [[package]] name = "concision-traits" -version = "0.2.9" +version = "0.3.0" dependencies = [ + "anyhow", "approx", - "getrandom", + "getrandom 0.3.4", "ndarray", "num-complex", "num-integer", @@ -348,6 +359,16 @@ dependencies = [ "variants", ] +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.7" @@ -356,10 +377,11 @@ checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" [[package]] name = "criterion" -version = "0.7.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1c047a62b0cc3e145fa84415a3191f628e980b194c2755aa12300a4e6cbd928" +checksum = "4d883447757bb0ee46f233e9dc22eb84d93a9508c9b868687b274fc431d886bf" dependencies = [ + "alloca", "anes", "cast", "ciborium", @@ -368,6 +390,7 @@ dependencies = [ "itertools", "num-traits", "oorandom", + "page_size", "plotters", "rayon", "regex", @@ -379,9 +402,9 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.6.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b1bcc0dc7dfae599d84ad0b1a55f80cde8af3725da8313b528da95ef783e338" +checksum = "ed943f81ea2faa8dcecbbfa50164acf95d555afec96a27871663b300e387b2e4" dependencies = [ "cast", "itertools", @@ -480,12 +503,37 @@ version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +[[package]] +name = "encoding_rs" +version = "0.8.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" +dependencies = [ + "cfg-if", +] + [[package]] name = "equivalent" version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "errno" +version = "0.3.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" +dependencies = [ + "libc", + "windows-sys 0.61.2", +] + +[[package]] +name = "fastrand" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + [[package]] name = "find-msvc-tools" version = "0.1.5" @@ -498,6 +546,27 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + +[[package]] +name = "foreign-types" +version = "0.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" +dependencies = [ + "foreign-types-shared", +] + +[[package]] +name = "foreign-types-shared" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" + [[package]] name = "form_urlencoded" version = "1.2.2" @@ -522,6 +591,12 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-sink" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e575fab7d1e0dcb8d0c7bcf9a63ee213816ab51902e6d244a95819acacf1d4f7" + [[package]] name = 
"futures-task" version = "0.3.31" @@ -540,6 +615,17 @@ dependencies = [ "pin-utils", ] +[[package]] +name = "getrandom" +version = "0.2.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "335ff9f135e4384c8150d6f27c6daed433577f86b4750418338c01a1a2528592" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + [[package]] name = "getrandom" version = "0.3.4" @@ -554,6 +640,25 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "h2" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3c0b69cfcb4e1b9f1bf2f53f95f766e4661169728ec61cd3fe5a0166f2d1386" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http", + "indexmap 2.12.1", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "half" version = "2.7.1" @@ -576,6 +681,15 @@ name = "hashbrown" version = "0.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "841d1cc9bed7f9236f321df977030373f4a4163ae1a7dbfe1a51a2c1a51d9100" +dependencies = [ + "allocator-api2", + "equivalent", + "foldhash", + "rayon", + "rustc-std-workspace-alloc", + "serde", + "serde_core", +] [[package]] name = "heck" @@ -638,6 +752,7 @@ dependencies = [ "bytes", "futures-channel", "futures-core", + "h2", "http", "http-body", "httparse", @@ -649,11 +764,43 @@ dependencies = [ "want", ] +[[package]] +name = "hyper-rustls" +version = "0.27.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3c93eb611681b207e1fe55d5a71ecf91572ec8a6705cdb6857f7d8d5242cf58" +dependencies = [ + "http", + "hyper", + "hyper-util", + "rustls", + "rustls-pki-types", + "tokio", + "tokio-rustls", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper", + "hyper-util", + "native-tls", + "tokio", + "tokio-native-tls", + "tower-service", +] + [[package]] name = "hyper-util" -version = "0.1.18" +version = "0.1.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52e9a2a24dc5c6821e71a7030e1e14b7b632acac55c40e9d2e082c621261bb56" +checksum = "727805d60e7938b76b826a6ef209eb70eaa1812794f9424d4a4e2d740662df5f" dependencies = [ "base64", "bytes", @@ -668,9 +815,11 @@ dependencies = [ "percent-encoding", "pin-project-lite", "socket2", + "system-configuration", "tokio", "tower-service", "tracing", + "windows-registry", ] [[package]] @@ -861,9 +1010,9 @@ checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" [[package]] name = "js-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b011eec8cc36da2aab2d5cff675ec18454fad408585853910a202391cf9f8e65" +checksum = "464a3709c7f55f1f721e5389aa6ea4e3bc6aba669353300af094b29ffbdde1d8" dependencies = [ "once_cell", "wasm-bindgen", @@ -877,9 +1026,9 @@ checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "libc" -version = "0.2.177" +version = "0.2.178" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2874a2af47a2325c2001a6e6fad9b16a53b802102b528163885171cf92b15976" +checksum = "37c93d8daa9d8a012fd8ab92f088405fb202ea0b6ab73ee2482ae66af4f42091" [[package]] name = "libm" @@ -887,6 +1036,12 @@ version = "0.2.15" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" +[[package]] +name = "linux-raw-sys" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" + [[package]] name = "litemap" version = "0.8.1" @@ -895,9 +1050,9 @@ checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" [[package]] name = "log" -version = "0.4.28" +version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34080505efa8e45a4b816c349525ebe327ceaa8559756f0356cba97ef3bf7432" +checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" [[package]] name = "matchers" @@ -924,17 +1079,40 @@ version = "2.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f52b00d39961fc5b2736ea853c9cc86238e165017a493d1d5c8eac6bdc4cc273" +[[package]] +name = "mime" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" + [[package]] name = "mio" -version = "1.1.0" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69d83b0086dc8ecf3ce9ae2874b2d1290252e2a30720bea58a5c6639b0092873" +checksum = "a69bcab0ad47271a0234d9422b131806bf3968021e5dc9328caf2d4cd58557fc" dependencies = [ "libc", "wasi", "windows-sys 0.61.2", ] +[[package]] +name = "native-tls" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87de3442987e9dbec73158d5c715e7ad9072fda936bb03d19d7fa10e00520f0e" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "ndarray" version = "0.17.1" @@ -1061,6 +1239,60 @@ version = "11.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6790f58c7ff633d8771f42965289203411a5e5c68388703c06e14f24770b41e" +[[package]] +name = "openssl" +version = "0.10.75" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08838db121398ad17ab8531ce9de97b244589089e290a384c900cb9ff7434328" +dependencies = [ + "bitflags", + "cfg-if", + "foreign-types", + "libc", + "once_cell", + "openssl-macros", + "openssl-sys", +] + +[[package]] +name = "openssl-macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "openssl-probe" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d05e27ee213611ffe7d6348b942e8f942b37114c00cc03cec254295a4a17852e" + +[[package]] +name = "openssl-sys" +version = "0.9.111" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82cab2d520aa75e3c58898289429321eb788c3106963d0dc886ec7a5f4adc321" +dependencies = [ + "cc", + "libc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "paste" version = "1.0.15" @@ -1085,6 +1317,12 @@ version = "0.1.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "pkg-config" +version = "0.3.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c" + [[package]] name = "plotters" version = "0.3.7" @@ -1227,7 +1465,7 @@ version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" dependencies = [ - "getrandom", + "getrandom 0.3.4", "serde", ] @@ -1302,27 +1540,35 @@ checksum = "7a2d987857b319362043e95f5353c0535c1f58eec5336fdfcf626430af7def58" [[package]] name = "reqwest" -version = "0.12.24" +version = "0.12.25" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d0946410b9f7b082a427e4ef5c8ff541a88b357bc6c637c40db3a68ac70a36f" +checksum = "b6eff9328d40131d43bd911d42d79eb6a47312002a4daefc9e37f17e74a7701a" dependencies = [ "base64", "bytes", + "encoding_rs", "futures-core", + "h2", "http", "http-body", "http-body-util", "hyper", + "hyper-rustls", + "hyper-tls", "hyper-util", "js-sys", "log", + "mime", + "native-tls", "percent-encoding", "pin-project-lite", + "rustls-pki-types", "serde", "serde_json", "serde_urlencoded", "sync_wrapper", "tokio", + "tokio-native-tls", "tower", "tower-http", "tower-service", @@ -1332,6 +1578,26 @@ dependencies = [ "web-sys", ] +[[package]] +name = "ring" +version = "0.17.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4689e6c2294d81e88dc6261c768b63bc4fcdb852be6d1352498b114f61383b7" +dependencies = [ + "cc", + "cfg-if", + "getrandom 0.2.16", + "libc", + "untrusted", + "windows-sys 0.52.0", +] + +[[package]] +name = "rustc-std-workspace-alloc" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d441c3b2ebf55cebf796bfdc265d67fa09db17b7bb6bd4be75c509e1e8fec3" + [[package]] name = "rustfft" version = "6.4.1" @@ -1346,6 +1612,52 @@ dependencies = [ "transpose", ] +[[package]] +name = "rustix" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cd15f8a2c5551a84d56efdc1cd049089e409ac19a3072d5037a17fd70719ff3e" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.61.2", +] + +[[package]] +name = "rustls" +version = "0.23.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "533f54bc6a7d4f647e46ad909549eda97bf5afc1585190ef692b4286b198bd8f" +dependencies = [ + "once_cell", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "708c0f9d5f54ba0272468c1d306a52c495b31fa155e91bc25371e6df7996908c" +dependencies = [ + "zeroize", +] + +[[package]] +name = "rustls-webpki" +version = "0.103.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ffdfa2f5286e2247234e03f680868ac2815974dc39e00ea15adc445d0aafe52" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.22" @@ -1367,6 +1679,38 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "schannel" +version = "0.1.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d81b926048e76efe18581bf793546b4c0eaf8448d72be8de2bbee5fd166e1" +dependencies = [ + 
"windows-sys 0.61.2", +] + +[[package]] +name = "security-framework" +version = "2.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02" +dependencies = [ + "bitflags", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cc1f0cbffaac4852523ce30d8bd3c5cdc873501d96ff467ca09b6767bb8cd5c0" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "serde" version = "1.0.228" @@ -1467,6 +1811,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "slab" +version = "0.4.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a2ae44ef20feb57a68b23d846850f861394c2e02dc425a50098ae8c90267589" + [[package]] name = "smallvec" version = "1.15.1" @@ -1533,6 +1883,12 @@ dependencies = [ "syn", ] +[[package]] +name = "subtle" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" + [[package]] name = "syn" version = "2.0.111" @@ -1564,6 +1920,40 @@ dependencies = [ "syn", ] +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags", + "core-foundation", + "system-configuration-sys", +] + +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + +[[package]] +name = "tempfile" +version = "3.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d31c77bdf42a745371d260a26ca7163f1e0924b64afa0b688e61b5a9fa02f16" +dependencies = [ + "fastrand", + "getrandom 0.3.4", + "once_cell", + "rustix", + "windows-sys 0.61.2", +] + [[package]] name = "thiserror" version = "2.0.17" @@ -1650,6 +2040,7 @@ version = "1.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff360e02eab121e0bc37a2d3b4d4dc622e6eda3a8e5253d5435ecf5bd4c68408" dependencies = [ + "bytes", "libc", "mio", "pin-project-lite", @@ -1657,6 +2048,39 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" +dependencies = [ + "rustls", + "tokio", +] + +[[package]] +name = "tokio-util" +version = "0.7.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2efa149fe76073d6e8fd97ef4f4eca7b67f599660115591483572e406e165594" +dependencies = [ + "bytes", + "futures-core", + "futures-sink", + "pin-project-lite", + "tokio", +] + [[package]] name = "tower" version = "0.5.2" @@ -1674,9 +2098,9 @@ dependencies 
= [ [[package]] name = "tower-http" -version = "0.6.7" +version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9cf146f99d442e8e68e585f5d798ccd3cad9a7835b917e09728880a862706456" +checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" dependencies = [ "bitflags", "bytes", @@ -1704,9 +2128,9 @@ checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" -version = "0.1.41" +version = "0.1.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +checksum = "2d15d90a0b5c19378952d479dc858407149d7bb45a14de0142f6c534b16fc647" dependencies = [ "log", "pin-project-lite", @@ -1732,22 +2156,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a04e24fab5c89c6a36eb8558c9656f30d81de51dfa4d3b45f26b21d61fa0a6c" dependencies = [ "once_cell", + "valuable", +] + +[[package]] +name = "tracing-log" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee855f1f400bd0e5c02d150ae5de3840039a3f54b025156404e34c23c03f47c3" +dependencies = [ + "log", + "once_cell", + "tracing-core", ] [[package]] name = "tracing-subscriber" -version = "0.3.20" +version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" +checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ "matchers", "nu-ansi-term", "once_cell", "regex-automata", "sharded-slab", + "smallvec", "thread_local", "tracing", "tracing-core", + "tracing-log", ] [[package]] @@ -1772,6 +2210,12 @@ version = "1.0.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9312f7c4f6ff9069b165498234ce8be658059c6728633667c526e27dc2cf1df5" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "url" version = "2.5.7" @@ -1790,6 +2234,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" +[[package]] +name = "valuable" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba73ea9cf16a25df0c8caa16c51acb937d5712a8429db78a3ee29d5dcacd3a65" + [[package]] name = "variants" version = "0.0.1" @@ -1824,6 +2274,12 @@ dependencies = [ "syn", ] +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + [[package]] name = "walkdir" version = "2.5.0" @@ -1860,9 +2316,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da95793dfc411fbbd93f5be7715b0578ec61fe87cb1a42b12eb625caa5c5ea60" +checksum = "0d759f433fa64a2d763d1340820e46e111a7a5ab75f993d1852d70b03dbb80fd" dependencies = [ "cfg-if", "once_cell", @@ -1873,9 +2329,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.55" +version = "0.4.56" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "551f88106c6d5e7ccc7cd9a16f312dd3b5d36ea8b4954304657d5dfba115d4a0" +checksum = 
"836d9622d604feee9e5de25ac10e3ea5f2d65b41eac0d9ce72eb5deae707ce7c" dependencies = [ "cfg-if", "js-sys", @@ -1886,9 +2342,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "04264334509e04a7bf8690f2384ef5265f05143a4bff3889ab7a3269adab59c2" +checksum = "48cb0d2638f8baedbc542ed444afc0644a29166f1595371af4fecf8ce1e7eeb3" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -1896,9 +2352,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "420bc339d9f322e562942d52e115d57e950d12d88983a14c79b86859ee6c7ebc" +checksum = "cefb59d5cd5f92d9dcf80e4683949f15ca4b511f4ac0a6e14d4e1ac60c6ecd40" dependencies = [ "bumpalo", "proc-macro2", @@ -1909,9 +2365,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.105" +version = "0.2.106" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "76f218a38c84bcb33c25ec7059b07847d465ce0e0a76b995e134a45adcb6af76" +checksum = "cbc538057e648b67f72a982e708d485b2efa771e1ac05fec311f9f63e5800db4" dependencies = [ "unicode-ident", ] @@ -1929,14 +2385,30 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.82" +version = "0.3.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a1f95c0d03a47f4ae1f7a64643a6bb97465d9b740f0fa8f90ea33915c99a9a1" +checksum = "9b32828d774c412041098d182a8b38b16ea816958e07cf40eec2bc080ae137ac" dependencies = [ "js-sys", "wasm-bindgen", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.11" @@ -1946,6 +2418,12 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + [[package]] name = "windows-core" version = "0.62.2" @@ -1987,6 +2465,17 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-registry" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "02752bf7fbdcce7f2a27a742f798510f3e5ad88dbe84871e5168e2120c3d5720" +dependencies = [ + "windows-link", + "windows-result", + "windows-strings", +] + [[package]] name = "windows-result" version = "0.4.1" @@ -2005,13 +2494,22 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.6", +] + [[package]] name = "windows-sys" version = "0.60.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" dependencies = [ - "windows-targets", + "windows-targets 0.53.5", ] [[package]] @@ -2023,6 +2521,22 @@ dependencies = [ "windows-link", ] +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", + "windows_i686_gnullvm 0.52.6", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", +] + [[package]] name = "windows-targets" version = "0.53.5" @@ -2030,58 +2544,106 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4945f9f551b88e0d65f3db0bc25c33b8acea4d9e41163edf90dcd0b19f9069f3" dependencies = [ "windows-link", - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", + "windows_aarch64_gnullvm 0.53.1", + "windows_aarch64_msvc 0.53.1", + "windows_i686_gnu 0.53.1", + "windows_i686_gnullvm 0.53.1", + "windows_i686_msvc 0.53.1", + "windows_x86_64_gnu 0.53.1", + "windows_x86_64_gnullvm 0.53.1", + "windows_x86_64_msvc 0.53.1", ] +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + [[package]] name = "windows_aarch64_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9d8416fa8b42f5c947f8482c43e7d89e73a173cead56d044f6a56104a6d1b53" +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + [[package]] name = "windows_aarch64_msvc" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b9d782e804c2f632e395708e99a94275910eb9100b2114651e04744e9b125006" +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + [[package]] name = "windows_i686_gnu" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "960e6da069d81e09becb0ca57a65220ddff016ff2d6af6a223cf372a506593a3" +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + [[package]] name = "windows_i686_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa7359d10048f68ab8b09fa71c3daccfb0e9b559aed648a8f95469c27057180c" +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + [[package]] name = "windows_i686_msvc" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e7ac75179f18232fe9c285163565a57ef8d3c89254a30685b57d83a38d326c2" +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + [[package]] name = "windows_x86_64_gnu" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c3842cdd74a865a8066ab39c8a7a473c0778a3f29370b5fd6b4b9aa7df4a499" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + [[package]] name = "windows_x86_64_gnullvm" version = "0.53.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ffa179e2d07eee8ad8f57493436566c7cc30ac536a3379fdf008f47f6bb7ae1" +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + [[package]] name = "windows_x86_64_msvc" version = "0.53.1" @@ -2125,18 +2687,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.30" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ea879c944afe8a2b25fef16bb4ba234f47c694565e97383b36f3a878219065c" +checksum = "fd74ec98b9250adb3ca554bdde269adf631549f51d8a8f8f0a10b50f1cb298c3" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.30" +version = "0.8.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf955aa904d6040f70dc8e9384444cb1030aed272ba3cb09bbc4ab9e7c1f34f5" +checksum = "d8a8d209fdf45cf5138cbb5a506f6b52522a25afccc534d1475dad8e31105c6a" dependencies = [ "proc-macro2", "quote", @@ -2164,6 +2726,12 @@ dependencies = [ "synstructure", ] +[[package]] +name = "zeroize" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" + [[package]] name = "zerotrie" version = "0.2.3" diff --git a/Cargo.toml b/Cargo.toml index f71f5605..f6cb91dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,33 +24,35 @@ license = "Apache-2.0" readme = "README.md" repository = "https://github.com/FL03/concision.git" rust-version = "1.85.0" -version = "0.2.9" +version = "0.3.0" [workspace.dependencies] -concision = { default-features = false, path = "concision", version = "0.2.9" } - -concision-core = { default-features = false, path = "core", version = "0.2.9" } -concision-data = { default-features = false, path = "data", version = "0.2.9" } -concision-derive = { default-features = false, path = "derive", version = "0.2.9" } -concision-init = { default-features = false, path = "init", version = "0.2.9" } -concision-macros = { default-features = false, path = "macros", version = "0.2.9" } -concision-neural = { default-features = false, path = "neural", version = "0.2.9" } -concision-params = { default-features = false, path = "params", version = "0.2.9" } -concision-traits = { default-features = false, path = "traits", version = "0.2.9" } -# extras -concision-ext = { default-features = false, path = "ext", version = "0.2.9" } - +concision = { default-features = false, path = "concision", version = "0.3.0" } +# sdk +concision-core = { default-features = false, path = "core", version = "0.3.0" } +concision-data = { default-features = false, path = "data", version = "0.3.0" } +concision-derive = { default-features = false, path = "derive", version = "0.3.0" } +concision-init = { 
default-features = false, path = "init", version = "0.3.0" } +concision-macros = { default-features = false, path = "macros", version = "0.3.0" } +concision-neural = { default-features = false, path = "neural", version = "0.3.0" } +concision-params = { default-features = false, path = "params", version = "0.3.0" } +concision-traits = { default-features = false, path = "traits", version = "0.3.0" } +# custom models & extras +concision-ext = { default-features = false, path = "ext", version = "0.3.0" } # custom variants = { default-features = false, features = ["derive"], version = "0.0.1" } +# data structures +hashbrown = { default-features = false, version = "0.16" } # tensors & arrays ndarray = { default-features = false, version = "0.17" } ndarray-linalg = { default-features = false, version = "0.18" } ndarray-stats = "0.6" # benchmarking -criterion = { version = "0.7" } +criterion = { version = "0.8" } # concurrency & parallelism crossbeam = { default-features = false, version = "0.8" } rayon = { default-features = false, version = "1" } +rayon-core = { default-features = false, version = "1" } # data & serialization serde = { default-features = false, features = ["derive"], version = "1" } serde_derive = { default-features = false, version = "1" } @@ -65,23 +67,21 @@ rustfft = { version = "6" } # random getrandom = { default-features = false, version = "0.3" } rand = { default-features = false, version = "0.9" } +rand_core = { default-features = false, version = "0.9" } rand_distr = { default-features = false, version = "0.5" } -uuid = { default-features = false, version = "1" } # errors anyhow = { default-features = false, version = "1" } thiserror = { default-features = false, version = "2" } -# networking -reqwest = { default-features = false, version = "0.12" } # logging tracing = { default-features = false, features = ["attributes", "log"], version = "0.1" } -tracing-subscriber = { default-features = false, features = ["ansi", "env-filter", "fmt"], version = "0.3" } +tracing-subscriber = { features = ["ansi", "env-filter", "fmt"], version = "0.3" } # time chrono = { default-features = false, version = "0.4" } humantime = { version = "2" } time = { default-features = false, version = "0.3" } # macros and utilities +convert_case = { version = "0.10" } either = { version = "1" } -itertools = { version = "0.14" } lazy_static = { version = "1" } paste = { version = "1" } smart-default = "0.7" diff --git a/LICENSE b/LICENSE index c8aeb62c..dbe4ad1d 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2025 Joe McCain III + Copyright 2025 Joe McCain III Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. diff --git a/README.md b/README.md index 8ce9f6c2..099fbef3 100644 --- a/README.md +++ b/README.md @@ -30,11 +30,18 @@ _**Warning: The library still in development and is not yet ready for production - `Model`: A trait for defining a neural network model. - `Predict`: A trait extending the basic [`Forward`](cnc::Forward) pass. - `Train`: A trait for training a neural network model. -- [ ] **v2**: - - [ ] **Models**: - - `Trainer`: A generic model trainer that can be used to train any model. - - [ ] **Layers**: Implement a standard model configuration and parameters. - - `LayerBase`: _functional_ wrappers for the `ParamsBase` structure. 
+- [x] **v2**:
+  - [x] **`DeepModelParams`**: Extend the `ParamsBase` structure to support deep neural networks with multiple layers.
+  - [x] **Models**: Implement standard model configurations and parameters.
+    - `StandardModelConfig`: A standard configuration for neural network models.
+    - `ModelFeatures`: A structure to define the features of a model (e.g., number of layers, neurons per layer).
+  - [x] **Activation Functions**: Implement and refine various activation functions (`ReLU`, `Sigmoid`, `Tanh`, etc.).
+  - [x] **Loss Functions**: Implement common loss functions such as `MeanSquaredError` and `CrossEntropy`.
+- [ ] **v3**:
+  - [ ] **Optimizers**: Implement optimization algorithms like `SGD` and `Adam`.
+  - [ ] **Scheduler**: Learning rate schedulers to adjust the learning rate during training.
+  - [ ] **Layers**: Move toward a more functional layer-based architecture.
+  - [ ] **Utilities**: Additional utilities for data preprocessing, model evaluation, and visualization.
 
 ## Usage
diff --git a/SECURITY.md b/SECURITY.md
index f78aac3a..a298596c 100644
--- a/SECURITY.md
+++ b/SECURITY.md
@@ -1,14 +1,15 @@
 # Security Policy
+
+Current Version: 0.3.0
 
 ## Supported Versions
 
-Checkout the current and supported packages below
+Check out the table below to see which versions are currently supported:
 
-| Version           | Supported          |
-|:------------------|:-------------------|
-| 0.2.9             | :white_check_mark: |
-| <=0.2.0,>0.2.9    | :white_check_mark: |
-| <=0.2.0           | :x:                |
+| Version          | Supported          |
+|:-----------------|:-------------------|
+| >=0.2.9, <=0.3.0 | :white_check_mark: |
+| <0.2.9           | :x:                |
 
 ## Reporting a Vulnerability
diff --git a/clippy.toml b/clippy.toml
index e69de29b..1a78622e 100644
--- a/clippy.toml
+++ b/clippy.toml
@@ -0,0 +1 @@
+msrv = "1.85.0"
diff --git a/concision/Cargo.toml b/concision/Cargo.toml
index 9c1b0497..4a591402 100644
--- a/concision/Cargo.toml
+++ b/concision/Cargo.toml
@@ -1,38 +1,36 @@
 [package]
 build = "build.rs"
-name = "concision"
-
-authors.workspace = true
-categories.workspace = true
-description.workspace = true
-edition.workspace = true
-homepage.workspace = true
-keywords.workspace = true
-license.workspace = true
-readme.workspace = true
-repository.workspace = true
+name                   = "concision"
+
+authors.workspace      = true
+categories.workspace   = true
+description.workspace  = true
+edition.workspace      = true
+homepage.workspace     = true
+keywords.workspace     = true
+license.workspace      = true
+readme.workspace       = true
+repository.workspace   = true
 rust-version.workspace = true
-version.workspace = true
+version.workspace      = true
 
 [package.metadata.docs.rs]
-all-features = false
-doc-scrape-examples = true
-features = ["full"]
-rustc-args = ["--cfg", "docsrs"]
-version = "v{{version}}"
+all-features = false
+features     = ["default"]
+rustc-args   = ["--cfg", "docsrs"]
 
 [package.metadata.release]
 no-dev-version = true
-tag-name = "{{version}}"
+tag-name = "v{{version}}"
 
 [lib]
-bench = true
+bench      = true
 crate-type = ["cdylib", "rlib"]
-doc = true
-doctest = true
-name = "cnc"
-path = "lib.rs"
-test = true
+doc        = true
+doctest    = true
+name       = "concision"
+path       = "lib.rs"
+test       = true
 
 [[bench]]
 harness = false
@@ -46,35 +44,25 @@ required-features = [
 ]
 
 [[example]]
-name = "params"
+name              = "params"
 required-features = ["rand", "std", "tracing"]
 
 [[example]]
-name = "simple"
+name              = "simple"
 required-features = ["rand", "std", "tracing"]
 
-[[test]]
-name = "default"
-
-[[test]]
-name = "simple"
-required-features = ["rand", "std"]
-
 [dependencies]
-concision-core = { workspace = true }
-concision-data = { optional = true, workspace = true }
+concision-core   = { workspace = true }
+concision-data   = { optional = true, workspace = true }
 concision-derive = { optional = true, workspace = true }
 concision-macros = { optional = true, workspace = true }
 
 [dev-dependencies]
-anyhow = { features = ["std"], workspace = true }
-approx = { workspace = true }
-criterion = { features = ["plotters"], workspace = true }
-lazy_static = { workspace = true }
-ndarray = { workspace = true }
-num = { features = ["rand", "serde"], workspace = true }
-serde = { features = ["std"], workspace = true }
-tracing = { features = ["attributes", "log", "std"], workspace = true }
+anyhow             = { features = ["std"], workspace = true }
+approx             = { workspace = true }
+criterion          = { features = ["plotters"], workspace = true }
+ndarray            = { workspace = true }
+tracing            = { features = ["attributes", "log", "std"], workspace = true }
 tracing-subscriber = { workspace = true }
 
 [features]
@@ -106,8 +94,6 @@ macros = ["concision-core/macros", "dep:concision-macros"]
 
 data = ["dep:concision-data"]
 
-neural = ["alloc"]
-
 # ************* [FF:Environments] *************
 std = [
     "alloc",
diff --git a/concision/benches/params.rs b/concision/benches/params.rs
index 06b51887..33b14d16 100644
--- a/concision/benches/params.rs
+++ b/concision/benches/params.rs
@@ -2,7 +2,9 @@
     appellation: params
     authors: @FL03
 */
-use cnc::init::InitRand;
+extern crate concision as cnc;
+
+use cnc::init::NdInit;
 use core::hint::black_box;
 
 use criterion::{BatchSize, BenchmarkId, Criterion};
@@ -25,15 +27,12 @@ fn bench_params_forward(c: &mut Criterion) {
         b.iter_batched(
             || {
                 let params = cnc::Params::<f64>::glorot_normal((n, 64));
+                let input = Array1::<f64>::linspace(0.0, 1.0, x);
                 // return the configured parameters
-                params
+                (params, input)
             },
-            |params| {
-                let input = Array1::<f64>::linspace(0.0, 1.0, x);
-                let y = params
-                    .forward(black_box(&input))
-                    .expect("Forward pass failed");
-                y
+            |(params, input)| {
+                params.forward(black_box(&input))
             },
             BatchSize::SmallInput,
         );
diff --git a/concision/examples/params.rs b/concision/examples/params.rs
index a8d1c8fb..6dc386bc 100644
--- a/concision/examples/params.rs
+++ b/concision/examples/params.rs
@@ -2,7 +2,9 @@
     Appellation: params
     Contrib: FL03
 */
-use cnc::init::InitRand;
+extern crate concision as cnc;
+
+use cnc::init::NdInit;
 use cnc::params::Params;
 use ndarray::prelude::*;
 
@@ -31,7 +33,7 @@ fn main() -> anyhow::Result<()> {
     assert_eq!(params.bias().shape(), &[n]);
     tracing::info!("Randomized parameters: {params:?}");
 
-    let y = params.forward(&inputs).expect("forward pass failed");
+    let y = params.forward(&inputs);
     assert_eq!(y.shape(), &[n]);
     tracing::info!("Forward pass: {y:?}");
diff --git a/concision/examples/simple.rs b/concision/examples/simple.rs
index b399bdac..f8a0e46b 100644
--- a/concision/examples/simple.rs
+++ b/concision/examples/simple.rs
@@ -1,9 +1,11 @@
 /*
-    Appellation: linear
+    Appellation: simple
     Created At: 2025.11.26:14:10:58
     Contrib: @FL03
 */
-use cnc::models::ex::sample::TestModel;
+extern crate concision as cnc;
+
+use cnc::ex::sample::TestModel;
 use cnc::{ModelFeatures, Predict, StandardModelConfig, Train};
 use ndarray::prelude::*;
 
@@ -30,9 +32,7 @@ fn main() -> anyhow::Result<()> {
     // initialize some input data
     let input = Array1::linspace(1.0, 9.0, model.features().input());
     // propagate the input through the model
-    let output = model
-        .predict(&input)
-        .expect("Failed to forward the input through the model");
+    let output = model.predict(&input);
     tracing::info!("output:
{:?}", output); // verify the output shape assert_eq!(output.dim(), (model.features().output())); @@ -44,9 +44,7 @@ fn main() -> anyhow::Result<()> { model.train(&training_input, &expected_output)?; } // forward the input through the model - let output = model - .predict(&input) - .expect("Failed to forward the input through the model"); + let output = model.predict(&input); tracing::info!("output: {:?}", output); Ok(()) diff --git a/core/Cargo.toml b/core/Cargo.toml index 0104084a..3bdf8da4 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -18,21 +18,18 @@ version.workspace = true all-features = false features = ["full"] rustc-args = ["--cfg", "docsrs"] -version = "v{{version}}" [package.metadata.release] no-dev-version = true -tag-name = "{{version}}" +tag-name = "v{{version}}" [lib] -crate-type = ["cdylib", "rlib"] bench = false -doc = true +crate-type = ["cdylib", "rlib"] doctest = true test = true [dependencies] -# local concision-init = { workspace = true } concision-params = { workspace = true } concision-traits = { workspace = true } @@ -41,6 +38,7 @@ variants = { workspace = true } # concurrency & parallelism rayon = { optional = true, workspace = true } # data structures +hashbrown = { workspace = true } ndarray = { workspace = true } # serialization serde = { features = ["derive"], optional = true, workspace = true } @@ -61,10 +59,6 @@ num-complex = { optional = true, workspace = true } num-integer = { workspace = true } num-traits = { workspace = true } rustfft = { optional = true, workspace = true } -# random -getrandom = { default-features = false, optional = true, workspace = true } -rand = { optional = true, workspace = true } -rand_distr = { optional = true, workspace = true } [dev-dependencies] lazy_static = { workspace = true } @@ -73,9 +67,9 @@ lazy_static = { workspace = true } default = ["macros", "std"] full = [ - "default", "approx", "complex", + "default", "json", "rand", "serde", @@ -87,6 +81,7 @@ nightly = [ "concision-init/nightly", "concision-params/nightly", "concision-traits/nightly", + "hashbrown/nightly", ] # ************* [FF:Features] ************* @@ -105,12 +100,11 @@ std = [ "concision-init/std", "concision-params/std", "concision-traits/std", + "hashbrown/default", "ndarray/std", - "num-integer/std", "num-complex?/std", + "num-integer/std", "num-traits/std", - "rand?/std", - "rand?/std_rng", "serde/std", "strum/std", "thiserror/std", @@ -125,7 +119,6 @@ wasi = [ ] wasm = [ - "getrandom?/wasm_js", "concision-init/wasm", "concision-params/wasm", "concision-traits/wasm", @@ -136,6 +129,7 @@ alloc = [ "concision-init/alloc", "concision-params/alloc", "concision-traits/alloc", + "hashbrown/alloc", "serde?/alloc", "serde_json?/alloc", "variants/alloc", @@ -158,31 +152,27 @@ complex = [ "dep:num-complex", "concision-init/complex", "concision-params/complex", + "concision-traits/complex", ] rand = [ - "dep:rand", - "dep:rand_distr", "concision-init/rand", "concision-params/rand", "concision-traits/rand", - "num-complex?/rand", - "rng", ] rayon = [ "dep:rayon", "concision-params/rayon", + "concision-traits/rayon", + "hashbrown/rayon", "ndarray/rayon", ] rng = [ - "dep:getrandom", "concision-init/rng", "concision-params/rng", - "concision-traits/rng", - "rand?/small_rng", - "rand?/thread_rng", + "concision-traits/rand", ] rustfft = ["dep:rustfft"] @@ -192,27 +182,11 @@ serde = [ "dep:serde_derive", "concision-init/serde", "concision-params/serde", + "hashbrown/serde", "ndarray/serde", "num-complex?/serde", - "rand?/serde", - "rand_distr?/serde", ] 
 serde_json = ["dep:serde_json"]
 
-tracing = [
-    "concision-init/tracing",
-    "dep:tracing",
-]
-
-# ************* [Unit Tests] *************
-[[test]]
-name = "params"
-required-features = ["approx", "std"]
-
-[[test]]
-name = "fft"
-required-features = ["alloc", "signal"]
-
-[[test]]
-name = "utils"
\ No newline at end of file
+tracing = ["concision-init/tracing", "dep:tracing"]
diff --git a/core/src/activate/impls/impl_linear.rs b/core/src/activate/impls/impl_linear.rs
index 29c4b42c..d26f5370 100644
--- a/core/src/activate/impls/impl_linear.rs
+++ b/core/src/activate/impls/impl_linear.rs
@@ -3,7 +3,7 @@
     Contrib: FL03
 */
 
-impl<T> crate::activate::LinearActivation for &T
+impl<T> crate::activate::LinearActivation for T
 where
     T: Clone + Default,
 {
diff --git a/core/src/activate/mod.rs b/core/src/activate/mod.rs
index 3e7872a9..9c463931 100644
--- a/core/src/activate/mod.rs
+++ b/core/src/activate/mod.rs
@@ -12,15 +12,17 @@
 //! manifesting in a number of traits, utilities, and other primitives used to define various
 //! approaches to activation functions.
 //!
-//! - [Heavyside]
-//! - [LinearActivation]
-//! - [Sigmoid]
-//! - [Softmax]
-//! - [ReLU]
-//! - [Tanh]
+//! - [`HeavysideActivation`]
+//! - [`LinearActivation`]
+//! - [`SigmoidActivation`]
+//! - [`SoftmaxActivation`]
+//! - [`ReLUActivation`]
+//! - [`TanhActivation`]
 //!
 #[doc(inline)]
-pub use self::prelude::*;
+pub use self::{traits::*, utils::*};
+
+pub(crate) mod utils;
 
 pub(crate) mod traits {
     #[doc(inline)]
@@ -37,21 +39,6 @@ pub(crate) mod traits {
     }
 }
 
-pub(crate) mod utils {
-    #[doc(inline)]
-    pub use self::prelude::*;
-
-    mod non_linear;
-    mod simple;
-
-    mod prelude {
-        #[doc(inline)]
-        pub use super::non_linear::*;
-        #[doc(inline)]
-        pub use super::simple::*;
-    }
-}
-
 mod impls {
     mod impl_binary;
     mod impl_linear;
@@ -59,8 +46,6 @@ mod impls {
 }
 
 pub(crate) mod prelude {
-    #[doc(inline)]
     pub use super::traits::*;
-    #[doc(inline)]
     pub use super::utils::*;
 }
diff --git a/core/src/activate/traits/activate.rs b/core/src/activate/traits/activate.rs
index 8e671253..f91761b4 100644
--- a/core/src/activate/traits/activate.rs
+++ b/core/src/activate/traits/activate.rs
@@ -15,74 +15,86 @@ use num_traits::One;
 /// The trait is generic over a type `U`, which represents the data type of the input to the
 /// activation functions. The trait also inherits a type alias `Cont` to allow for variance
 /// w.r.t. the outputs of defined methods.
-pub trait Rho<U>: Apply<U> {
+pub trait Rho<T> {
+    type Cont<_V>;
+
+    fn rho<V, F>(&self, f: F) -> Self::Cont<V>
+    where
+        F: Fn(T) -> V;
     /// the linear activation function is essentially a passthrough function, simply cloning
     /// the content.
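+    /// e.g., applying [`linear`](Rho::linear) to a container simply clones its
+    /// contents (a hypothetical sketch, assuming an `ndarray`-backed container):
+    ///
+    /// ```ignore
+    /// let x = ndarray::array![1.0, 2.0, 3.0];
+    /// assert_eq!(x.linear(), x);
+    /// ```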
-    fn linear(&self) -> Self::Cont<U> {
-        self.apply(|x| x)
+    fn linear(&self) -> Self::Cont<T> {
+        self.rho(|x| x)
     }
 
-    fn linear_derivative(&self) -> Self::Cont<U>
+    fn linear_derivative(&self) -> Self::Cont<T>
     where
-        U: One,
+        T: One,
     {
-        self.apply(|_| <U as One>::one())
+        self.rho(|_| <T as One>::one())
     }
 
-    fn heavyside(&self) -> Self::Cont<U>
+    fn heavyside(&self) -> Self::Cont<T>
     where
-        U: HeavysideActivation,
+        T: HeavysideActivation,
     {
-        self.apply(|x| x.heavyside())
+        self.rho(|x| x.heavyside())
     }
 
-    fn heavyside_derivative(&self) -> Self::Cont<U>
+    fn heavyside_derivative(&self) -> Self::Cont<T>
     where
-        U: HeavysideActivation,
+        T: HeavysideActivation,
     {
-        self.apply(|x| x.heavyside_derivative())
+        self.rho(|x| x.heavyside_derivative())
     }
 
-    fn relu(&self) -> Self::Cont<U>
+    fn relu(&self) -> Self::Cont<T>
     where
-        U: ReLUActivation,
+        T: ReLUActivation,
     {
-        self.apply(|x| x.relu())
+        self.rho(|x| x.relu())
     }
 
-    fn relu_derivative(&self) -> Self::Cont<U>
+    fn relu_derivative(&self) -> Self::Cont<T>
     where
-        U: ReLUActivation,
+        T: ReLUActivation,
     {
-        self.apply(|x| x.relu_derivative())
+        self.rho(|x| x.relu_derivative())
     }
 
-    fn sigmoid(&self) -> Self::Cont<U>
+    fn sigmoid(&self) -> Self::Cont<T>
     where
-        U: SigmoidActivation,
+        T: SigmoidActivation,
     {
-        self.apply(|x| x.sigmoid())
+        self.rho(|x| x.sigmoid())
     }
 
-    fn sigmoid_derivative(&self) -> Self::Cont<U>
+    fn sigmoid_derivative(&self) -> Self::Cont<T>
     where
-        U: SigmoidActivation,
+        T: SigmoidActivation,
    {
-        self.apply(|x| x.sigmoid_derivative())
+        self.rho(|x| x.sigmoid_derivative())
     }
 
-    fn tanh(&self) -> Self::Cont<U>
+    fn tanh(&self) -> Self::Cont<T>
     where
-        U: TanhActivation,
+        T: TanhActivation,
     {
-        self.apply(|x| x.tanh())
+        self.rho(|x| x.tanh())
     }
 
-    fn tanh_derivative(&self) -> Self::Cont<U>
+    fn tanh_derivative(&self) -> Self::Cont<T>
     where
-        U: TanhActivation,
+        T: TanhActivation,
     {
-        self.apply(|x| x.tanh_derivative())
+        self.rho(|x| x.tanh_derivative())
     }
 }
@@ -99,27 +111,36 @@
 /// to use a fully qualified syntax to disambiguate the two traits. If this becomes a problem,
 /// we may consider renaming the _complex_ methods accordingly to differentiate them from the
 /// _standard_ methods.
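+///
+/// For instance, disambiguating the two `sigmoid` implementations with fully
+/// qualified syntax might look like the following (a hypothetical sketch; it
+/// assumes a container type implementing both [`Rho`] and [`RhoComplex`]):
+///
+/// ```ignore
+/// let real = Rho::sigmoid(&values);
+/// let complex = RhoComplex::sigmoid(&values);
+/// ```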
-pub trait RhoComplex: Apply +pub trait RhoComplex: Rho where U: ComplexFloat, { fn sigmoid(&self) -> Self::Cont { - self.apply(|x| U::one() / (U::one() + (-x).exp())) + self.rho(|x| U::one() / (U::one() + (-x).exp())) } fn sigmoid_derivative(&self) -> Self::Cont { - self.apply(|x| { + self.rho(|x| { let s = U::one() / (U::one() + (-x).exp()); s * (U::one() - s) }) } fn tanh(&self) -> Self::Cont { - self.apply(|x| x.tanh()) + self.rho(|x| x.tanh()) } fn tanh_derivative(&self) -> Self::Cont { - self.apply(|x| { + self.rho(|x| { let s = x.tanh(); U::one() - s * s }) @@ -129,7 +134,19 @@ where /* ************* Implementations ************* */ -impl Rho for S where S: Apply {} +impl Rho for S +where + S: Apply, +{ + type Cont<_V> = S::Cont<_V>; + + fn rho(&self, f: F) -> Self::Cont + where + F: Fn(T) -> V, + { + self.apply(|x| f(x)) + } +} #[cfg(feature = "complex")] impl RhoComplex for S diff --git a/core/src/activate/utils/non_linear.rs b/core/src/activate/utils.rs similarity index 69% rename from core/src/activate/utils/non_linear.rs rename to core/src/activate/utils.rs index 61d218ac..dc78948e 100644 --- a/core/src/activate/utils/non_linear.rs +++ b/core/src/activate/utils.rs @@ -8,7 +8,7 @@ use num_traits::{Float, One, Zero}; /// the relu activation function: /// /// ```math -/// f(x) = \max(0, x) +/// \mbox{f}(x) = \max(0, x) /// ``` pub fn relu(args: T) -> T where @@ -17,6 +17,10 @@ where if args > T::zero() { args } else { T::zero() } } +/// +/// ```math +/// \frac{df}{dx}=\max(0,1) +/// ``` pub fn relu_derivative(args: T) -> T where T: PartialOrd + One + Zero, @@ -30,7 +34,7 @@ where /// the sigmoid activation function: /// /// ```math -/// f(x) = \frac{1}{1 + e^{-x}} +/// \mbox{f}(x)=\frac{1}{1+\exp(-x)} /// ``` pub fn sigmoid(args: T) -> T where @@ -94,3 +98,36 @@ where let t = tanh(args); T::one() - t * t } + +/// the [`linear`] method is essentially a _passthrough_ method often used in simple models +/// or layers where no activation is needed. +pub const fn linear(x: T) -> T { + x +} + +/// the [`linear_derivative`] method always returns `1` as it is a simple, single variable +/// function +pub fn linear_derivative() -> T +where + T: One, +{ + ::one() +} + +/// Heaviside activation function: +/// +/// ```math +/// H(x) = +/// \left\{ +/// \begin{array}{rcl} +/// 1 & \mbox{if} & x\gt{0} \\ +/// 0 & \mbox{if} & x\leq{0} +/// \end{array} +/// \right. +/// ``` +pub fn heavyside(x: T) -> T +where + T: One + PartialOrd + Zero, +{ + if x > T::zero() { T::one() } else { T::zero() } +} diff --git a/core/src/activate/utils/simple.rs b/core/src/activate/utils/simple.rs deleted file mode 100644 index fb5f04a2..00000000 --- a/core/src/activate/utils/simple.rs +++ /dev/null @@ -1,38 +0,0 @@ -/* - Appellation: utils - Contrib: FL03 -*/ -use num_traits::{One, Zero}; - -/// the [`linear`] method is essentially a _passthrough_ method often used in simple models -/// or layers where no activation is needed. -pub const fn linear(x: T) -> T { - x -} - -/// the [`linear_derivative`] method always returns `1` as it is a simple, single variable -/// function -pub fn linear_derivative() -> T -where - T: One, -{ - ::one() -} - -/// Heaviside activation function: -/// -/// ```math -/// H(x) = -/// \left\{ -/// \begin{array}{rcl} -/// 1 & \mbox{if} & x\gt{0} \\ -/// 0 & \mbox{if} & x\leq{0} -/// \end{array} -/// \right. 
-/// ``` -pub fn heavyside(x: T) -> T -where - T: One + PartialOrd + Zero, -{ - if x > T::zero() { T::one() } else { T::zero() } -} diff --git a/core/src/config/hyper_params.rs b/core/src/config/hyper_params.rs new file mode 100644 index 00000000..a4d490a8 --- /dev/null +++ b/core/src/config/hyper_params.rs @@ -0,0 +1,136 @@ +/* + Appellation: hyper_params + Contrib: @FL03 +*/ +#[cfg(feature = "alloc")] +use alloc::string::{String, ToString}; + +/// An enumeration of common HyperParams used in neural network configurations. +#[derive( + Clone, + Debug, + Eq, + Hash, + Ord, + PartialEq, + PartialOrd, + strum::EnumCount, + strum::EnumIs, + strum::EnumIter, + strum::VariantNames, +)] +#[cfg_attr( + feature = "serde", + derive(serde::Deserialize, serde::Serialize), + serde(rename_all = "snake_case", untagged) +)] +pub enum HyperParam { + Decay, + #[serde(alias = "drop_out", alias = "p")] + Dropout, + #[serde(alias = "lr", alias = "gamma")] + LearningRate, + Momentum, + Temperature, + WeightDecay, + Beta1, + Beta2, + Epsilon, + #[cfg(feature = "alloc")] + Custom(String), +} + +impl HyperParam { + #[cfg(feature = "alloc")] + /// creates a custom hyperparameter variant + pub fn custom(name: T) -> Self { + HyperParam::Custom(name.to_string()) + } + /// returns a list of variants as strings + pub const fn variants() -> &'static [&'static str] { + use strum::VariantNames; + HyperParam::VARIANTS + } +} + +impl core::fmt::Display for HyperParam { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str(self.as_ref()) + } +} + +impl AsRef for HyperParam { + fn as_ref(&self) -> &str { + match self { + HyperParam::Decay => "decay", + HyperParam::Dropout => "dropout", + HyperParam::LearningRate => "learning_rate", + HyperParam::Momentum => "momentum", + HyperParam::Temperature => "temperature", + HyperParam::WeightDecay => "weight_decay", + HyperParam::Beta1 => "beta1", + HyperParam::Beta2 => "beta2", + HyperParam::Epsilon => "epsilon", + HyperParam::Custom(s) => s.as_ref(), + } + } +} + +impl core::borrow::Borrow for HyperParam { + fn borrow(&self) -> &str { + self.as_ref() + } +} + +#[cfg(feature = "alloc")] +impl core::convert::From for HyperParam { + fn from(s: String) -> Self { + core::str::FromStr::from_str(&s).expect("Failed to convert String to HyperParams") + } +} + +impl From<&str> for HyperParam { + fn from(s: &str) -> Self { + core::str::FromStr::from_str(s).expect("Failed to convert &str to HyperParams") + } +} + +impl core::str::FromStr for HyperParam { + type Err = core::convert::Infallible; + + fn from_str(s: &str) -> Result { + match s { + "decay" => Ok(HyperParam::Decay), + "dropout" => Ok(HyperParam::Dropout), + "learning_rate" => Ok(HyperParam::LearningRate), + "momentum" => Ok(HyperParam::Momentum), + "temperature" => Ok(HyperParam::Temperature), + "weight_decay" => Ok(HyperParam::WeightDecay), + #[cfg(feature = "alloc")] + other => Ok(HyperParam::Custom(other.to_string())), + } + } +} + +#[cfg(test)] +mod tests { + use super::HyperParam; + + #[test] + fn test_hyper_params() { + use HyperParam::*; + + assert_eq!(HyperParam::from("learning_rate"), LearningRate); + + assert_eq!(HyperParam::from("weight_decay"), WeightDecay); + + for v in ["something", "another_param", "custom_hyperparam"] { + let param = HyperParam::from(v); + assert!(param.is_custom()); + match param { + HyperParam::Custom(s) => assert_eq!(s, v), + _ => panic!("Expected Custom variant"), + } + } + } +} diff --git a/core/src/config/mod.rs b/core/src/config/mod.rs index 035aa9a5..a13bf3d9 
100644 --- a/core/src/config/mod.rs +++ b/core/src/config/mod.rs @@ -5,29 +5,79 @@ //! This module is dedicated to establishing common interfaces for valid configuration objects //! while providing a standard implementation to quickly spin up a new model. #[doc(inline)] -pub use self::{model_config::StandardModelConfig, traits::*, types::*}; +pub use self::{hyper_params::HyperParam, model_config::StandardModelConfig}; +pub mod hyper_params; pub mod model_config; +pub(crate) mod prelude { + pub use super::hyper_params::HyperParam; + pub use super::model_config::*; + pub use super::{ExtendedModelConfig, ModelConfiguration, RawConfig}; +} -mod traits { - #[doc(inline)] - pub use self::config::*; - - mod config; +/// The [`RawConfig`] trait defines a basic interface for all _configurations_ used within the +/// framework for neural networks, their layers, and more. +pub trait RawConfig { + type Ctx; } -mod types { - #[doc(inline)] - pub use self::{hyper_params::*, key_value::*}; +/// The [`ModelConfiguration`] trait extends the [`RawConfig`] trait to provide a more robust +/// interface for neural network configurations. +pub trait ModelConfiguration: RawConfig { + fn get(&self, key: K) -> Option<&T> + where + K: AsRef; + fn get_mut(&mut self, key: K) -> Option<&mut T> + where + K: AsRef; + + fn set(&mut self, key: K, value: T) -> Option + where + K: AsRef; + fn remove(&mut self, key: K) -> Option + where + K: AsRef; + fn contains(&self, key: K) -> bool + where + K: AsRef; - mod hyper_params; - mod key_value; + fn keys(&self) -> Vec; } -pub(crate) mod prelude { - pub use super::model_config::*; - pub use super::traits::*; - pub use super::types::*; +macro_rules! hyperparam_method { + ($($(dyn)? $name:ident::<$type:ty>),* $(,)?) => { + $( + hyperparam_method!(@impl $name::<$type>); + )* + }; + (@impl dyn $name:ident::<$type:ty>) => { + fn $name(&self) -> Option<&$type> where T: 'static { + self.get(stringify!($name)).map(|v| v.downcast_ref::<$type>()).flatten() + } + }; + (@impl $name:ident::<$type:ty>) => { + fn $name(&self) -> Option<&$type> { + self.get(stringify!($name)) + } + }; +} + +pub trait ExtendedModelConfig: ModelConfiguration { + fn epochs(&self) -> usize; + + fn batch_size(&self) -> usize; + + hyperparam_method! 
{ + learning_rate::, + epsilon::, + momentum::, + weight_decay::, + dropout::, + decay::, + beta::, + beta1::, + beta2::, + } } #[cfg(test)] diff --git a/core/src/config/model_config.rs b/core/src/config/model_config.rs index 1c4d4aa9..acfc28e2 100644 --- a/core/src/config/model_config.rs +++ b/core/src/config/model_config.rs @@ -2,26 +2,21 @@ Appellation: config Contrib: @FL03 */ -use super::Hyperparameters::*; +use super::HyperParam; use super::{ExtendedModelConfig, ModelConfiguration, RawConfig}; - -#[cfg(all(feature = "alloc", not(feature = "std")))] -pub(crate) type ModelConfigMap = alloc::collections::BTreeMap; -#[cfg(feature = "std")] -pub(crate) type ModelConfigMap = std::collections::HashMap; +use hashbrown::DefaultHashBuilder; +use hashbrown::hash_map::{self, HashMap}; #[derive(Clone, Debug)] -#[cfg_attr(feature = "serde", derive(serde_derive::Deserialize, serde::Serialize))] +#[cfg_attr( + feature = "serde", + derive(serde::Deserialize, serde::Serialize), + serde(rename = "snake_case") +)] pub struct StandardModelConfig { - pub(crate) batch_size: usize, - pub(crate) epochs: usize, - pub(crate) hyperparameters: ModelConfigMap, -} - -impl Default for StandardModelConfig { - fn default() -> Self { - Self::new() - } + pub batch_size: usize, + pub epochs: usize, + pub hyperspace: HashMap, } impl StandardModelConfig { @@ -29,7 +24,7 @@ impl StandardModelConfig { Self { batch_size: 0, epochs: 0, - hyperparameters: ModelConfigMap::new(), + hyperspace: HashMap::new(), } } /// returns a copy of the batch size @@ -49,35 +44,39 @@ impl StandardModelConfig { &mut self.epochs } /// returns a reference to the hyperparameters map - pub const fn hyperparameters(&self) -> &ModelConfigMap { - &self.hyperparameters + pub const fn hyperparameters(&self) -> &HashMap { + &self.hyperspace } /// returns a mutable reference to the hyperparameters map - pub const fn hyperparameters_mut(&mut self) -> &mut ModelConfigMap { - &mut self.hyperparameters + pub const fn hyperparameters_mut(&mut self) -> &mut HashMap { + &mut self.hyperspace } /// inserts a hyperparameter into the map, returning the previous value if it exists - pub fn add_parameter(&mut self, key: impl ToString, value: T) -> Option { - self.hyperparameters_mut().insert(key.to_string(), value) + pub fn add_parameter>(&mut self, key: P, value: T) -> Option { + self.hyperparameters_mut().insert(key.into(), value) } /// gets a reference to a hyperparameter by key, returning None if it does not exist pub fn get_parameter(&self, key: &Q) -> Option<&T> where Q: ?Sized + Eq + core::hash::Hash, - String: core::borrow::Borrow, + HyperParam: core::borrow::Borrow, { self.hyperparameters().get(key) } /// returns an entry for the hyperparameter, allowing for insertion or modification - pub fn parameter(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T> + pub fn parameter(&mut self, key: Q) -> hash_map::Entry<'_, HyperParam, T, DefaultHashBuilder> where - Q: ToString, + Q: AsRef, { - self.hyperparameters_mut().entry(key.to_string()) + self.hyperparameters_mut().entry(key.as_ref().into()) } /// removes a hyperparameter from the map, returning the value if it exists - pub fn remove_hyperparameter(&mut self, key: impl ToString) -> Option { - self.hyperparameters_mut().remove(&key.to_string()) + pub fn remove_hyperparameter(&mut self, key: &Q) -> Option + where + Q: ?Sized + core::hash::Hash + Eq, + HyperParam: core::borrow::Borrow, + { + self.hyperparameters_mut().remove(key) } /// sets the batch size, returning a mutable reference to the 
current instance pub fn set_batch_size(&mut self, batch_size: usize) -> &mut Self { @@ -97,6 +96,11 @@ impl StandardModelConfig { pub fn with_epochs(self, epochs: usize) -> Self { Self { epochs, ..self } } +} + +use HyperParam::*; + +impl StandardModelConfig { /// sets the decay hyperparameter, returning the previous value if it exists pub fn set_decay(&mut self, decay: T) -> Option { self.add_parameter(Decay, decay) @@ -110,23 +114,29 @@ impl StandardModelConfig { } /// sets the weight decay hyperparameter, returning the previous value if it exists pub fn set_weight_decay(&mut self, decay: T) -> Option { - self.add_parameter("weight_decay", decay) + self.add_parameter(WeightDecay, decay) } /// returns a reference to the learning rate hyperparameter, if it exists pub fn learning_rate(&self) -> Option<&T> { - self.get_parameter(LearningRate.as_ref()) + self.get_parameter(&LearningRate) } /// returns a reference to the momentum hyperparameter, if it exists pub fn momentum(&self) -> Option<&T> { - self.get_parameter(Momentum.as_ref()) + self.get_parameter(&Momentum) } /// returns a reference to the decay hyperparameter, if it exists pub fn decay(&self) -> Option<&T> { - self.get_parameter(Decay.as_ref()) + self.get_parameter(&Decay) } /// returns a reference to the weight decay hyperparameter, if it exists pub fn weight_decay(&self) -> Option<&T> { - self.get_parameter("weight_decay") + self.get_parameter(&WeightDecay) + } +} + +impl Default for StandardModelConfig { + fn default() -> Self { + Self::new() } } @@ -158,7 +168,7 @@ impl ModelConfiguration for StandardModelConfig { K: AsRef, { self.hyperparameters_mut() - .insert(key.as_ref().to_string(), value) + .insert(key.as_ref().into(), value) } fn remove(&mut self, key: K) -> Option @@ -175,7 +185,7 @@ impl ModelConfiguration for StandardModelConfig { self.hyperparameters().contains_key(key.as_ref()) } - fn keys(&self) -> Vec { + fn keys(&self) -> Vec { self.hyperparameters().keys().cloned().collect() } } @@ -189,21 +199,3 @@ impl ExtendedModelConfig for StandardModelConfig { self.batch_size } } -#[allow(deprecated)] -impl StandardModelConfig { - #[deprecated(since = "0.1.0", note = "Use `add_parameter` instead.")] - pub fn insert_parameter(&mut self, key: impl ToString, value: T) -> Option { - self.add_parameter(key, value) - } - #[deprecated(since = "0.1.0", note = "Use `parameter` instead.")] - pub fn hyperparam(&mut self, key: Q) -> std::collections::hash_map::Entry<'_, String, T> - where - Q: ToString, - { - self.parameter(key) - } - #[deprecated(since = "0.1.0", note = "Use `get_parameter` instead.")] - pub fn get(&self, key: impl ToString) -> Option<&T> { - self.hyperparameters().get(&key.to_string()) - } -} diff --git a/core/src/config/traits/config.rs b/core/src/config/traits/config.rs deleted file mode 100644 index bec53ffb..00000000 --- a/core/src/config/traits/config.rs +++ /dev/null @@ -1,69 +0,0 @@ -/* - Appellation: config - Contrib: @FL03 -*/ - -/// The [`RawConfig`] trait defines a basic interface for all _configurations_ used within the -/// framework for neural networks, their layers, and more. -pub trait RawConfig { - type Ctx; -} - -/// The [`ModelConfiguration`] trait extends the [`RawConfig`] trait to provide a more robust -/// interface for neural network configurations. 
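As a quick illustration of the keying change above: hyperparameters now live in a map keyed by `HyperParam` rather than by `String`, with unknown names absorbed by the `Custom` variant. A reduced standalone sketch (std's `HashMap` in place of `hashbrown`, only a few variants):

```rust
use std::collections::HashMap;

// a reduced stand-in for the crate's `HyperParam`: well-known names get
// dedicated variants, anything else falls through to `Custom`
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum HyperParam {
    LearningRate,
    Momentum,
    Custom(String),
}

impl From<&str> for HyperParam {
    fn from(s: &str) -> Self {
        match s {
            "learning_rate" => Self::LearningRate,
            "momentum" => Self::Momentum,
            other => Self::Custom(other.to_string()),
        }
    }
}

fn main() {
    // the `hyperspace` map is keyed by the enum, not by raw strings
    let mut hyperspace: HashMap<HyperParam, f64> = HashMap::new();
    hyperspace.insert(HyperParam::from("learning_rate"), 1e-3);
    hyperspace.insert(HyperParam::from("warmup_steps"), 500.0);

    assert_eq!(hyperspace.get(&HyperParam::LearningRate), Some(&1e-3));
    assert_eq!(
        hyperspace.get(&HyperParam::Custom("warmup_steps".into())),
        Some(&500.0)
    );
}
```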
-pub trait ModelConfiguration: RawConfig { - fn get(&self, key: K) -> Option<&T> - where - K: AsRef; - fn get_mut(&mut self, key: K) -> Option<&mut T> - where - K: AsRef; - - fn set(&mut self, key: K, value: T) -> Option - where - K: AsRef; - fn remove(&mut self, key: K) -> Option - where - K: AsRef; - fn contains(&self, key: K) -> bool - where - K: AsRef; - - fn keys(&self) -> Vec; -} - -macro_rules! hyperparam_method { - ($($(dyn)? $name:ident::<$type:ty>),* $(,)?) => { - $( - hyperparam_method!(@impl $name::<$type>); - )* - }; - (@impl dyn $name:ident::<$type:ty>) => { - fn $name(&self) -> Option<&$type> where T: 'static { - self.get(stringify!($name)).map(|v| v.downcast_ref::<$type>()).flatten() - } - }; - (@impl $name:ident::<$type:ty>) => { - fn $name(&self) -> Option<&$type> { - self.get(stringify!($name)) - } - }; -} - -pub trait ExtendedModelConfig: ModelConfiguration { - fn epochs(&self) -> usize; - - fn batch_size(&self) -> usize; - - hyperparam_method! { - learning_rate::, - epsilon::, - momentum::, - weight_decay::, - dropout::, - decay::, - beta::, - beta1::, - beta2::, - } -} diff --git a/core/src/config/types/hyper_params.rs b/core/src/config/types/hyper_params.rs deleted file mode 100644 index b64dd3b4..00000000 --- a/core/src/config/types/hyper_params.rs +++ /dev/null @@ -1,98 +0,0 @@ -/* - Appellation: hyperparameters - Contrib: @FL03 -*/ -use super::KeyValue; - -#[cfg(feature = "alloc")] -use alloc::string::String; - -/// An enumeration of common hyperparameters used in neural network configurations. -#[derive( - Clone, - Debug, - Eq, - Hash, - Ord, - PartialEq, - PartialOrd, - strum::EnumCount, - strum::EnumIs, - strum::EnumDiscriminants, -)] -#[cfg_attr( - feature = "serde", - derive(serde::Deserialize, serde::Serialize), - serde(rename_all = "snake_case", untagged), - strum_discriminants(derive(serde::Deserialize, serde::Serialize)) -)] -#[strum_discriminants( - name(Hyperparameters), - derive( - Hash, - Ord, - PartialOrd, - strum::AsRefStr, - strum::Display, - strum::EnumCount, - strum::EnumIs, - strum::EnumIter, - strum::EnumString, - strum::VariantArray, - strum::VariantNames, - variants::VariantConstructors, - ), - strum(serialize_all = "snake_case") -)] -#[strum(serialize_all = "snake_case")] -pub enum HyperParams { - Decay(T), - Dropout(T), - LearningRate(T), - Momentum(T), - Temperature(T), - WeightDecay(T), - Custom { key: String, value: T }, -} - -impl From> for HyperParams -where - K: AsRef, -{ - fn from(KeyValue { key, value }: KeyValue) -> Self { - match key.as_ref() { - "decay" => HyperParams::Decay(value), - "dropout" => HyperParams::Dropout(value), - "learning_rate" => HyperParams::LearningRate(value), - "momentum" => HyperParams::Momentum(value), - "temperature" => HyperParams::Temperature(value), - "weight_decay" => HyperParams::WeightDecay(value), - k => HyperParams::Custom { - key: String::from(k), - value, - }, - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use core::str::FromStr; - - #[test] - fn test_hyper() { - use strum::IntoEnumIterator; - - assert_eq!( - Hyperparameters::from_str("learning_rate"), - Ok(Hyperparameters::LearningRate) - ); - - for variant in Hyperparameters::iter() { - let name = variant.as_ref(); - let parsed = Hyperparameters::from_str(name); - assert_eq!(parsed, Ok(variant)); - } - } -} diff --git a/core/src/config/types/key_value.rs b/core/src/config/types/key_value.rs deleted file mode 100644 index cc85c73b..00000000 --- a/core/src/config/types/key_value.rs +++ /dev/null @@ -1,203 +0,0 @@ -/* - appellation: 
key_value - authors: @FL03 -*/ - -/// The [`KeyValue`] type is used to generically represent a simple key-value pair within a -/// store. -#[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] -#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] -pub struct KeyValue { - pub key: K, - pub value: V, -} - -impl KeyValue { - pub const fn new(key: K, value: V) -> Self { - Self { key, value } - } - /// returns a new [`KeyValue`] from the given key, using the logical default for the value - pub fn from_key(key: K) -> Self - where - V: Default, - { - Self { - key, - value: V::default(), - } - } - /// returns a new [`KeyValue`] from the given value, using the logical default for the key - pub fn from_value(value: V) -> Self - where - K: Default, - { - Self { - key: K::default(), - value, - } - } - /// returns an immutable reference to the key - pub const fn key(&self) -> &K { - &self.key - } - /// returns a mutable reference to the key - pub const fn key_mut(&mut self) -> &mut K { - &mut self.key - } - /// returns an immutable reference to the value - pub const fn value(&self) -> &V { - &self.value - } - /// returns a mutable reference to the value - pub const fn value_mut(&mut self) -> &mut V { - &mut self.value - } - /// update the current key and return a mutable reference to self - pub fn set_key(&mut self, key: K) -> &mut Self { - self.key = key; - self - } - /// update the current value and return a mutable reference to self - pub fn set_value(&mut self, value: V) -> &mut Self { - self.value = value; - self - } - /// consumes the current instance to create another with the given key - pub fn with_key(self, key: K2) -> KeyValue { - KeyValue { - key, - value: self.value, - } - } - /// consumes the current instance to create another with the given value - pub fn with_value(self, value: V2) -> KeyValue { - KeyValue { - key: self.key, - value, - } - } - /// [`replace`](core::mem::replace) the current value and return the old value - pub const fn replace_value(&mut self, value: V) -> V { - core::mem::replace(self.value_mut(), value) - } - /// [`swap`](core::mem::swap) the current value with another in the given instance - pub const fn swap_value(&mut self, other: &mut KeyValue) { - core::mem::swap(self.value_mut(), other.value_mut()) - } - /// [`take`](core::mem::take) the current value and return it, replacing it with the - /// logical default - pub fn take_value(&mut self) -> V - where - V: Default, - { - core::mem::take(self.value_mut()) - } - /// returns a new instance of the [`KeyValue`] with mutable references to the value and a - /// reference to the key - pub fn entry(&mut self) -> KeyValue<&K, &mut V> { - KeyValue { - key: &self.key, - value: &mut self.value, - } - } - /// returns a new instance of the [`KeyValue`] with references to the key and value - pub const fn view(&self) -> KeyValue<&K, &V> { - KeyValue { - key: self.key(), - value: self.value(), - } - } - /// returns a new instance of the [`KeyValue`] with mutable references to the current key - /// and value - pub const fn view_mut(&mut self) -> KeyValue<&mut K, &mut V> { - KeyValue { - key: &mut self.key, - value: &mut self.value, - } - } -} - -impl KeyValue<&K, &V> { - /// returns a new [`KeyValue`] instance with clones of the current key and value - pub fn cloned(&self) -> KeyValue - where - K: Clone, - V: Clone, - { - KeyValue { - key: self.key.clone(), - value: self.value.clone(), - } - } - /// returns a new [`KeyValue`] instance with copies of the current key and value - pub fn 
copied(&self) -> KeyValue - where - K: Copy, - V: Copy, - { - KeyValue { - key: *self.key, - value: *self.value, - } - } -} - -impl KeyValue<&K, &mut V> { - /// returns a new [`KeyValue`] instance with clones of the current key and value - pub fn cloned(&self) -> KeyValue - where - K: Clone, - V: Clone, - { - KeyValue { - key: self.key.clone(), - value: self.value.clone(), - } - } - /// returns a new [`KeyValue`] instance with copies of the current key and value - pub fn copied(&self) -> KeyValue - where - K: Copy, - V: Copy, - { - KeyValue { - key: *self.key, - value: *self.value, - } - } -} - -impl KeyValue<&mut K, &mut V> { - /// returns a new [`KeyValue`] instance with clones of the current key and value - pub fn cloned(&self) -> KeyValue - where - K: Clone, - V: Clone, - { - KeyValue { - key: self.key.clone(), - value: self.value.clone(), - } - } - /// returns a new [`KeyValue`] instance with copies of the current key and value - pub fn copied(&self) -> KeyValue - where - K: Copy, - V: Copy, - { - KeyValue { - key: *self.key, - value: *self.value, - } - } -} - -impl core::fmt::Display for KeyValue -where - K: core::fmt::Display, - V: core::fmt::Display, -{ - fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { - write!(f, "{k}: {v}", k = self.key(), v = self.value()) - } -} diff --git a/core/src/error.rs b/core/src/error.rs index 5388d629..56386a40 100644 --- a/core/src/error.rs +++ b/core/src/error.rs @@ -8,50 +8,47 @@ /// a type alias for a [`Result`](core::result::Result) defined to use the custom [`Error`] as its error type. pub type Result = core::result::Result; +#[cfg(feature = "alloc")] +use alloc::{boxed::Box, string::String}; /// The [`Error`] type enumerates various errors that can occur within the framework. #[derive(Debug, thiserror::Error)] #[non_exhaustive] pub enum Error { - #[cfg(feature = "alloc")] - #[error(transparent)] - AllocError(#[from] alloc_err::AllocError), - #[error(transparent)] - ExtError(#[from] CommonError), - #[error("The model is not trained")] - NotTrained, #[error("Invalid model configuration")] InvalidModelConfig, #[error("The model is not supported for the given input")] IncompatibleInput, - #[error("An unsupported operation was attempted")] - UnsupportedOperation, - #[error("Invalid Batch Size")] - InvalidBatchSize, - #[error("Invalid Input Shape")] - InvalidInputShape, - #[error("Invalid Output Shape")] - InvalidOutputShape, - #[error("Uninitialized")] + #[error("An invalid batch size was provided: {0}")] + InvalidBatchSize(usize), + #[error("Input is incompatible with the model: found {0} and expected {1}")] + InvalidInputFeatures(usize, usize), + #[error("The provided dataset has invalid target features: found {0} and expected {1}")] + InvalidTargetFeatures(usize, usize), + #[error("An uninitialized object was used")] Uninitialized, + #[error("The model is not trained")] + Untrained, + #[cfg(feature = "alloc")] #[error("Unsupported model {0}")] - UnsupportedModel(&'static str), + UnsupportedModel(String), + #[cfg(feature = "alloc")] + #[error("An unsupported operation was attempted: {0}")] + UnsupportedOperation(String), #[error("Parameter Error")] - ParameterError(&'static str), -} - -/// The [`CommonError`] type enumerates external errors handled by the framework -#[derive(Debug, thiserror::Error)] -pub enum CommonError { + ParameterError(String), #[error(transparent)] - PadError(#[from] crate::utils::pad::PadError), + AnyError(#[from] anyhow::Error), + #[cfg(feature = "alloc")] #[error(transparent)] - TraitError(#[from] 
concision_traits::Error), + BoxError(#[from] Box), + #[error(transparent)] + PadError(#[from] crate::utils::pad::PadError), #[error(transparent)] ParamError(#[from] concision_params::ParamsError), #[error(transparent)] InitError(#[from] concision_init::InitError), #[error(transparent)] - ShapeError(#[from] ndarray::ShapeError), + TraitError(#[from] concision_traits::Error), #[cfg(feature = "serde")] #[error(transparent)] DeserializeError(#[from] serde::de::value::Error), @@ -64,31 +61,19 @@ pub enum CommonError { #[error(transparent)] IoError(#[from] std::io::Error), #[error(transparent)] - #[cfg(feature = "rand")] - UniformError(#[from] rand_distr::uniform::Error), + ShapeError(#[from] ndarray::ShapeError), + #[error("Unknown Error: {0}")] + UnknownError(String), } -#[cfg(feature = "alloc")] -mod alloc_err { - use alloc::{boxed::Box, string::String}; - - #[derive(Debug, thiserror::Error)] - pub enum AllocError { - #[error(transparent)] - BoxError(#[from] Box), - #[error("Unknown Error: {0}")] - Unknown(String), - } - - impl From for AllocError { - fn from(value: String) -> Self { - Self::Unknown(value) - } +impl From for Error { + fn from(value: String) -> Self { + Self::UnknownError(value) } +} - impl From<&str> for AllocError { - fn from(value: &str) -> Self { - Self::Unknown(String::from(value)) - } +impl From<&str> for Error { + fn from(value: &str) -> Self { + Self::UnknownError(String::from(value)) } } diff --git a/core/src/models/ex/sample.rs b/core/src/ex/sample.rs similarity index 79% rename from core/src/models/ex/sample.rs rename to core/src/ex/sample.rs index 437d0023..4016f6d1 100644 --- a/core/src/models/ex/sample.rs +++ b/core/src/ex/sample.rs @@ -2,13 +2,14 @@ appellation: model authors: @FL03 */ +#![cfg(feature = "std")] use crate::activate::{ReLUActivation, SigmoidActivation}; use crate::{ DeepModelParams, Error, Forward, Model, ModelFeatures, Norm, Params, StandardModelConfig, Train, }; #[cfg(feature = "rand")] use concision_init::{ - InitRand, + NdInit, rand_distr::{Distribution, StandardNormal}, }; @@ -108,30 +109,25 @@ impl Model for TestModel { } } -impl Forward> for TestModel +impl Forward> for TestModel where A: Float + FromPrimitive + ScalarOperand, D: Dimension, S: Data, - Params: Forward, Output = Array>, + Params: Forward, Output = Array> + + Forward, Output = Array>, { type Output = Array; - fn forward(&self, input: &ArrayBase) -> Option { - let mut output = self - .params() - .input() - .forward_then(&input.to_owned(), |y| y.relu())?; - + fn forward(&self, input: &ArrayBase) -> Self::Output { + // complete the first forward pass using the input layer + let mut output = self.params().input().forward(input).relu(); + // complete the forward pass for each hidden layer for layer in self.params().hidden() { - output = layer.forward_then(&output, |y| y.relu())?; + output = layer.forward(&output).relu(); } - let y = self - .params() - .output() - .forward_then(&output, |y| y.sigmoid())?; - Some(y) + self.params().output().forward(&output).sigmoid() } } @@ -150,10 +146,16 @@ where target: &ArrayBase, ) -> Result { if input.len() != self.layout().input() { - return Err(Error::InvalidInputShape); + return Err(Error::InvalidInputFeatures( + input.len(), + self.layout().input(), + )); } if target.len() != self.layout().output() { - return Err(Error::InvalidOutputShape); + return Err(Error::InvalidTargetFeatures( + target.len(), + self.layout().output(), + )); } // get the learning rate from the model's configuration let lr = self @@ -170,28 +172,15 @@ where let mut 
activations = Vec::new(); activations.push(input.to_owned()); - let mut output = self - .params() - .input() - .forward(&input) - .expect("Failed to complete the forward pass for the input layer") - .relu(); + let mut output = self.params().input().forward_then(&input, |y| y.relu()); activations.push(output.to_owned()); // collect the activations of the hidden for layer in self.params().hidden() { - output = layer - .forward(&output) - .expect("failed to complete the forward pass for the hidden layer") - .relu(); + output = layer.forward(&output).relu(); activations.push(output.to_owned()); } - output = self - .params() - .output() - .forward(&output) - .expect("Output layer failed to forward propagate") - .sigmoid(); + output = self.params().output().forward(&output).sigmoid(); activations.push(output.to_owned()); // Calculate output layer error @@ -205,8 +194,7 @@ where // Update output weights self.params_mut() .output_mut() - .backward(activations.last().unwrap(), &delta, lr) - .expect("Output failed training..."); + .backward(activations.last().unwrap(), &delta, lr); let num_hidden = self.layout().layers(); // Iterate through hidden layers in reverse order @@ -222,9 +210,7 @@ where }; // Normalize delta to prevent exploding gradients delta /= delta.l2_norm(); - self.params_mut().hidden_mut()[i] - .backward(&activations[i + 1], &delta, lr) - .expect("Hidden failed training..."); + self.params_mut().hidden_mut()[i].backward(&activations[i + 1], &delta, lr); } /* The delta for the input layer is computed using the weights of the first hidden layer @@ -234,8 +220,7 @@ where delta /= delta.l2_norm(); // Normalize the delta to prevent exploding gradients self.params_mut() .input_mut() - .backward(&activations[1], &delta, lr) - .expect("failed to backpropagate input layer during training..."); + .backward(&activations[1], &delta, lr); Ok(loss) } @@ -254,15 +239,21 @@ where &mut self, input: &ArrayBase, target: &ArrayBase, - ) -> Result { + ) -> Result { if input.nrows() == 0 || target.nrows() == 0 { - return Err(Error::InvalidBatchSize); + return Err(anyhow::anyhow!("Input and target batches must be non-empty").into()); } if input.ncols() != self.layout().input() { - return Err(Error::InvalidInputShape); + return Err(Error::InvalidInputFeatures( + input.ncols(), + self.layout().input(), + )); } if target.ncols() != self.layout().output() || target.nrows() != input.nrows() { - return Err(Error::InvalidOutputShape); + return Err(Error::InvalidTargetFeatures( + target.ncols(), + self.layout().output(), + )); } let batch_size = input.nrows(); let mut loss = A::zero(); @@ -271,6 +262,13 @@ where loss += match Train::, ArrayView1>::train(self, &x, &e) { Ok(l) => l, Err(err) => { + #[cfg(not(feature = "tracing"))] + eprintln!( + "Training failed for batch {}/{}: {:?}", + i + 1, + batch_size, + err + ); #[cfg(feature = "tracing")] tracing::error!( "Training failed for batch {}/{}: {:?}", diff --git a/core/src/layers/layer.rs b/core/src/layers/layer.rs index 670112e0..687348a2 100644 --- a/core/src/layers/layer.rs +++ b/core/src/layers/layer.rs @@ -2,46 +2,31 @@ Appellation: layer Contrib: @FL03 */ -//! this module defines the [`LayerBase`] struct, a generic representation of a neural network -//! layer essentially wrapping a [`ParamsBase`] with some _activation function_, `F`. -//! 
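One step worth isolating from the `train` pass above: each `delta` is divided by its own L2 norm before back-propagation, bounding the update's magnitude to guard against exploding gradients. A minimal sketch over a plain slice, with a zero-norm guard added that the diff's version omits:

```rust
// plain `Vec<f64>` stands in for the `ndarray` types used in the diff
fn l2_norm(v: &[f64]) -> f64 {
    v.iter().map(|x| x * x).sum::<f64>().sqrt()
}

fn normalize(delta: &mut [f64]) {
    let norm = l2_norm(delta);
    // avoid dividing a zero gradient by zero
    if norm > 0.0 {
        delta.iter_mut().for_each(|x| *x /= norm);
    }
}

fn main() {
    let mut delta = vec![3.0, 4.0];
    normalize(&mut delta);
    assert_eq!(delta, vec![0.6, 0.8]);
    assert!((l2_norm(&delta) - 1.0).abs() < 1e-12);
}
```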
- mod impl_layer; -mod impl_layer_repr; - -#[allow(deprecated)] mod impl_layer_deprecated; +mod impl_layer_repr; use super::Activator; -use concision_params::ParamsBase; use concision_traits::Forward; -use ndarray::{DataOwned, Dimension, Ix2, RawData, RemoveAxis, ShapeBuilder}; -/// The [`LayerBase`] struct is a base representation of a neural network layer, essentially -/// binding an activation function, `F`, to a set of parameters, `ParamsBase`. -pub struct LayerBase -where - D: Dimension, - S: RawData, -{ +/// The [`Layer`] implementation works to provide a generic interface for layers within a +/// neural network. It associates an activation function of type `F` with parameters of +/// type `P`. +pub struct Layer { /// the activation function of the layer pub(crate) rho: F, /// the parameters of the layer is an object consisting of both a weight and a bias tensor. - pub(crate) params: ParamsBase, + pub(crate) params: P, } -impl LayerBase -where - D: Dimension, - S: RawData, -{ - /// create a new [`LayerBase`] from the given activation function and parameters. - pub const fn new(rho: F, params: ParamsBase) -> Self { +impl Layer { + /// create a new [`Layer`] from the given activation function and parameters. + pub const fn new(rho: F, params: P) -> Self { Self { rho, params } } - /// create a new [`LayerBase`] from the given parameters assuming the logical default for + /// create a new [`Layer`] from the given parameters assuming the logical default for /// the activation of type `F`. - pub fn from_params(params: ParamsBase) -> Self + pub fn from_params(params: P) -> Self where F: Default, { @@ -50,25 +35,22 @@ where params, } } - /// create a new [`LayerBase`] from the given activation function and shape. - pub fn from_rho(rho: F, shape: Sh) -> Self + /// create a new [`Layer`] from the given activation function and shape. + pub fn from_rho(rho: F) -> Self where - A: Clone + Default, - S: DataOwned, - D: RemoveAxis, - Sh: ShapeBuilder, + P: Default, { Self { rho, - params: ParamsBase::default(shape), + params:
::default(), } } /// returns an immutable reference to the layer's parameters - pub const fn params(&self) -> &ParamsBase { + pub const fn params(&self) -> &P { &self.params } /// returns a mutable reference to the layer's parameters - pub const fn params_mut(&mut self) -> &mut ParamsBase { + pub const fn params_mut(&mut self) -> &mut P { &mut self.params } /// returns an immutable reference to the activation function of the layer @@ -80,37 +62,35 @@ where &mut self.rho } /// consumes the current instance and returns another with the given parameters. - pub fn with_params(self, params: ParamsBase) -> LayerBase + pub fn with_params(self, params: Y) -> Layer where - S2: RawData, - D2: Dimension, + F: Activator, { - LayerBase { + Layer { rho: self.rho, params, } } /// consumes the current instance and returns another with the given activation function. /// This is useful during the creation of the model, when the activation function is not known yet. - pub fn with_rho(self, rho: G) -> LayerBase + pub fn with_rho(self, rho: G) -> Layer where - G: Activator, - F: Activator, - S: RawData, + G: Activator
, + F: Activator
, { - LayerBase { + Layer { rho, params: self.params, } } - pub fn forward(&self, input: &X) -> Option + /// given some input, complete a single forward pass through the layer + pub fn forward(&self, input: &U) -> V where - ParamsBase: Forward, - F: Activator< as Forward>::Output, Output = Y>, - A: Clone, - X: Clone, - Y: Clone, + P: Forward, + F: Activator, + V: Clone, { - Forward::forward(&self.params, input).map(|x| self.rho.activate(x)) + self.params() + .forward_then(input, |y| self.rho().activate(y)) } } diff --git a/core/src/layers/layer/impl_layer.rs b/core/src/layers/layer/impl_layer.rs index 0999f9c7..6c90fca8 100644 --- a/core/src/layers/layer/impl_layer.rs +++ b/core/src/layers/layer/impl_layer.rs @@ -2,100 +2,60 @@ appellation: impl_layer authors: @FL03 */ -use crate::layers::LayerBase; +use crate::layers::Layer; -use crate::layers::{Activator, ActivatorGradient, Layer}; -use concision_params::ParamsBase; +use crate::layers::{Activator, RawLayer}; +use concision_params::{ParamsBase, RawParam}; use concision_traits::Forward; -use ndarray::{Data, Dimension, RawData}; +use ndarray::{DataOwned, Dimension, RawData, RemoveAxis, ShapeBuilder}; -impl core::ops::Deref for LayerBase -where - D: Dimension, - S: RawData, -{ - type Target = ParamsBase; - - fn deref(&self) -> &Self::Target { - &self.params - } -} - -impl core::ops::DerefMut for LayerBase +impl Layer> where + F: Activator, D: Dimension, - S: RawData, + S: RawData, { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.params + /// create a new [`LayerBase`] from the given activation function and shape. + pub fn from_rho_with_shape(rho: F, shape: Sh) -> Self + where + A: Clone + Default, + S: DataOwned, + D: RemoveAxis, + Sh: ShapeBuilder, + { + Self { + rho, + params: ParamsBase::default(shape), + } } } -impl Forward for LayerBase +impl Forward for Layer where - A: Clone, F: Activator, - D: Dimension, - S: Data, - ParamsBase: Forward, + P: RawParam + Forward, { type Output = Y; - fn forward(&self, inputs: &X) -> Option { - let y = self - .params() - .forward(inputs) - .expect("Forward pass failed in LayerBase"); - - Some(self.rho().activate(y)) + fn forward(&self, input: &X) -> Self::Output { + self.rho().activate(self.params().forward(input)) } } -impl Activator for LayerBase +impl RawLayer for Layer where - F: Activator, - D: Dimension, - S: RawData, + F: Activator
, + P: RawParam, { - type Output = V; - - fn activate(&self, x: U) -> Self::Output { - self.rho().activate(x) - } -} - -impl ActivatorGradient for LayerBase -where - F: ActivatorGradient, - D: Dimension, - S: RawData, -{ - type Input = F::Input; - type Delta = F::Delta; - - fn activate_gradient(&self, inputs: F::Input) -> F::Delta { - self.rho().activate_gradient(inputs) - } -} - -impl Layer for LayerBase -where - F: Activator, - D: Dimension, - S: RawData, -{ - type Elem = A; - type Rho = F; - - fn rho(&self) -> &Self::Rho { + fn rho(&self) -> &F { &self.rho } - fn params(&self) -> &ParamsBase { + fn params(&self) -> &P { &self.params } - fn params_mut(&mut self) -> &mut ParamsBase { + fn params_mut(&mut self) -> &mut P { &mut self.params } } diff --git a/core/src/layers/layer/impl_layer_deprecated.rs b/core/src/layers/layer/impl_layer_deprecated.rs index 259c4d94..96d12c65 100644 --- a/core/src/layers/layer/impl_layer_deprecated.rs +++ b/core/src/layers/layer/impl_layer_deprecated.rs @@ -2,14 +2,9 @@ appellation: impl_layer_deprecated authors: @FL03 */ -use crate::layers::LayerBase; +#![allow(deprecated)] -use ndarray::{Dimension, RawData}; +use crate::layers::Layer; #[doc(hidden)] -impl LayerBase -where - D: Dimension, - S: RawData, -{ -} +impl Layer {} diff --git a/core/src/layers/layer/impl_layer_repr.rs b/core/src/layers/layer/impl_layer_repr.rs index 7a1be8f2..79ede5a1 100644 --- a/core/src/layers/layer/impl_layer_repr.rs +++ b/core/src/layers/layer/impl_layer_repr.rs @@ -2,20 +2,14 @@ appellation: impl_layer_repr authors: @FL03 */ -use crate::layers::layer::LayerBase; +use super::Layer; use crate::layers::{Linear, ReLU, Sigmoid, Tanh}; -use concision_params::ParamsBase; -use ndarray::{Dimension, RawData}; -impl LayerBase -where - D: Dimension, - S: RawData, -{ +impl Layer { /// initialize a new [`LayerBase`] using a [`Linear`] activation function and the given /// parameters. - pub const fn linear(params: ParamsBase) -> Self { + pub const fn linear(params: T) -> Self { Self { rho: Linear, params, @@ -23,14 +17,10 @@ where } } -impl LayerBase -where - D: Dimension, - S: RawData, -{ +impl Layer { /// initialize a new [`LayerBase`] using a [`Sigmoid`] activation function and the given /// parameters. - pub const fn sigmoid(params: ParamsBase) -> Self { + pub const fn sigmoid(params: T) -> Self { Self { rho: Sigmoid, params, @@ -38,25 +28,17 @@ where } } -impl LayerBase -where - D: Dimension, - S: RawData, -{ +impl Layer { /// initialize a new [`LayerBase`] using a [`Tanh`] activation function and the given /// parameters. - pub const fn tanh(params: ParamsBase) -> Self { + pub const fn tanh(params: T) -> Self { Self { rho: Tanh, params } } } -impl LayerBase -where - D: Dimension, - S: RawData, -{ +impl Layer { /// initialize a new [`LayerBase`] using a [`ReLU`] activation function and the given - pub const fn relu(params: ParamsBase) -> Self { + pub const fn relu(params: T) -> Self { Self { rho: ReLU, params } } } diff --git a/core/src/layers/mod.rs b/core/src/layers/mod.rs index cdfb8cce..a8aa7ca5 100644 --- a/core/src/layers/mod.rs +++ b/core/src/layers/mod.rs @@ -2,18 +2,23 @@ Appellation: layers Contrib: @FL03 */ -//! This module implments various layers for a neural network +//! This module provides the [`Layer`] implementation along with supporting traits and types. +//! +//! struct, a generic representation of a neural network +//! layer by associating +//! 
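The shape of the refactor above: `Layer<F, P>` pairs an arbitrary activator `F` with arbitrary params `P`, and a forward pass is just the params' affine map composed with the activation. A standalone sketch using a closure as the activator and a toy affine params type; the crate itself routes this through its `Activator` and `Forward` traits:

```rust
/// the refactored shape: any activation `F` paired with any params `P`
struct Layer<F, P> {
    rho: F,
    params: P,
}

impl<F, P> Layer<F, P> {
    const fn new(rho: F, params: P) -> Self {
        Self { rho, params }
    }
}

/// a toy params type computing `y = w * x + b`
struct Affine {
    weight: f64,
    bias: f64,
}

impl<F> Layer<F, Affine>
where
    F: Fn(f64) -> f64,
{
    /// forward pass: the params' affine map, then the activation
    fn forward(&self, input: f64) -> f64 {
        (self.rho)(self.params.weight * input + self.params.bias)
    }
}

fn main() {
    let relu = |x: f64| x.max(0.0);
    let layer = Layer::new(relu, Affine { weight: -2.0, bias: 0.5 });
    assert_eq!(layer.forward(1.0), 0.0); // -2.0 * 1.0 + 0.5 = -1.5, clipped to 0
    assert_eq!(layer.forward(-1.0), 2.5);
}
```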
#[doc(inline)] -pub use self::{layer::LayerBase, traits::*, types::*}; +pub use self::{layer::Layer, traits::*, types::*}; -pub(crate) mod layer; -pub mod sequential; +mod layer; + +pub mod seq; pub(crate) mod traits { #[doc(inline)] - pub use self::{activate::*, layers::*}; + pub use self::{activator::*, layers::*}; - mod activate; + mod activator; mod layers; mod store; } @@ -29,3 +34,48 @@ pub(crate) mod prelude { pub use super::layer::*; pub use super::types::*; } + +#[cfg(test)] +mod tests { + use super::*; + use concision_params::Params; + use ndarray::{Array1, array}; + + #[test] + fn test_linear_layer() { + let params = Params::from_elem((3, 2), 0.5_f32); + let layer = Layer::linear(params); + + assert_eq!(layer.params().shape(), &[3, 2]); + + let inputs = Array1::linspace(1.0_f32, 2.0_f32, 3); + println!("{:?}", inputs); + assert_eq!(layer.forward(&inputs), array![2.75, 2.75]); + } + + #[test] + fn test_relu_layer() { + let params = Params::from_elem((3, 2), 0.5_f32); + let layer = Layer::relu(params); + + assert_eq!(layer.params().shape(), &[3, 2]); + + let inputs = Array1::linspace(1.0_f32, 2.0_f32, 3); + assert_eq!(layer.forward(&inputs), array![2.75, 2.75]); + } + #[cfg(feature = "approx")] + #[test] + fn test_tanh_layer() { + let params = Params::from_elem((3, 2), 0.5_f32); + let layer = Layer::tanh(params); + + assert_eq!(layer.params().shape(), &[3, 2]); + + let inputs = Array1::linspace(1.0_f32, 2.0_f32, 3); + approx::assert_abs_diff_eq!( + layer.forward(&inputs), + Array1::from_elem(2, 0.99185973), + epsilon = 1e-6 + ); + } +} diff --git a/core/src/layers/sequential.rs b/core/src/layers/seq.rs similarity index 100% rename from core/src/layers/sequential.rs rename to core/src/layers/seq.rs diff --git a/core/src/layers/traits/activate.rs b/core/src/layers/traits/activator.rs similarity index 64% rename from core/src/layers/traits/activate.rs rename to core/src/layers/traits/activator.rs index 414b0ae3..27849177 100644 --- a/core/src/layers/traits/activate.rs +++ b/core/src/layers/traits/activator.rs @@ -13,11 +13,10 @@ pub trait Activator { /// The [`ActivatorGradient`] trait extends the [`Activator`] trait to include a method for /// computing the gradient of the activation function. pub trait ActivatorGradient: Activator { - type Input; type Delta; /// compute the gradient of some input - fn activate_gradient(&self, input: Self::Input) -> Self::Delta; + fn activate_gradient(&self, input: T) -> Self::Delta; } /* @@ -34,14 +33,13 @@ where } } -impl ActivatorGradient for &T +impl ActivatorGradient for &U where - T: ActivatorGradient, + U: ActivatorGradient, { - type Input = B; - type Delta = C; + type Delta = G; - fn activate_gradient(&self, inputs: Self::Input) -> Self::Delta { + fn activate_gradient(&self, inputs: A) -> Self::Delta { (*self).activate_gradient(inputs) } } @@ -55,16 +53,11 @@ impl Activator for dyn Fn(X) -> Y { } #[cfg(feature = "alloc")] -mod impl_alloc { - use super::Activator; - use alloc::boxed::Box; - - impl Activator for Box> { - type Output = Y; +impl Activator for alloc::boxed::Box> { + type Output = Y; - fn activate(&self, rhs: X) -> Self::Output { - self.as_ref().activate(rhs) - } + fn activate(&self, rhs: X) -> Self::Output { + self.as_ref().activate(rhs) } } @@ -72,7 +65,14 @@ mod impl_alloc { ************* Implementations ************* */ macro_rules! activator { - (@impl $vis:vis struct $name:ident::<$($trait:ident)::*>($method:ident) ) => { + ($( + $vis:vis struct $name:ident.$method:ident where $T:ident: $($trait:ident)::* + );* $(;)?) 
=> { + $( + activator!(@impl $vis struct $name.$method where $T: $($trait)::* ); + )* + }; + (@impl $vis:vis struct $name:ident.$method:ident where $T:ident: $($trait:ident)::* ) => { #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -94,7 +94,6 @@ macro_rules! activator { where U: $($trait)::*, { - type Input = U; type Delta = U::Output; fn activate_gradient(&self, inputs: U) -> Self::Delta { @@ -103,18 +102,13 @@ macro_rules! activator { } } }; - ($( - $vis:vis struct $name:ident::<$($trait:ident)::*>($method:ident) - );* $(;)?) => { - $( - activator!(@impl $vis struct $name::<$($trait)::*>($method)); - )* - }; } activator! { - pub struct Linear::(linear); - pub struct ReLU::(relu); - pub struct Sigmoid::(sigmoid); - pub struct Tanh::(tanh); + pub struct Linear.linear where T: crate::activate::LinearActivation; + pub struct ReLU.relu where T: crate::activate::ReLUActivation; + pub struct Sigmoid.sigmoid where T: crate::activate::SigmoidActivation; + pub struct Tanh.tanh where T: crate::activate::TanhActivation; + pub struct HeavySide.heavyside where T: crate::activate::HeavysideActivation; + pub struct Softmax.softmax where T: crate::activate::SoftmaxActivation; } diff --git a/core/src/layers/traits/layers.rs b/core/src/layers/traits/layers.rs index f3e5b81f..5a109921 100644 --- a/core/src/layers/traits/layers.rs +++ b/core/src/layers/traits/layers.rs @@ -4,55 +4,54 @@ */ use super::{Activator, ActivatorGradient}; -use concision_params::ParamsBase; +use concision_params::{ParamsBase, RawParam}; use concision_traits::{Backward, Forward}; use ndarray::{Data, Dimension, RawData}; +pub trait RawLayer::Elem> +where + F: Activator, + X: RawParam, +{ + /// the activation function of the layer + fn rho(&self) -> &F; + /// returns an immutable reference to the parameters of the layer + fn params(&self) -> &X; + /// returns a mutable reference to the parameters of the layer + fn params_mut(&mut self) -> &mut X; +} /// A generic trait defining the composition of a _layer_ within a neural network. -pub trait Layer +pub trait NdLayer::Elem> where D: Dimension, - S: RawData, + S: RawData, { - /// the type of element used within the layer; typically a floating-point variant like - /// [`f32`] or [`f64`]. - type Elem; /// The type of activator used by the layer; the type must implement [`ActivatorGradient`] - type Rho: Activator; + type Rho: Activator; fn rho(&self) -> &Self::Rho; /// returns an immutable reference to the parameters of the layer fn params(&self) -> &ParamsBase; /// returns a mutable reference to the parameters of the layer fn params_mut(&mut self) -> &mut ParamsBase; -} -/// The [`LayerExt`] trait extends the base [`Layer`] trait with additional methods that -/// are commonly used in neural network layers. It provides methods for setting parameters, -/// performing backward propagation of errors, and completing a forward pass through the layer. 
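For orientation, here is roughly what a single entry of the rewritten `activator!` macro above expands to, e.g. `pub struct ReLU.relu where T: ReLUActivation`; this sketch swaps the crate's `ReLUActivation` bound for a local `PartialOrd + Default` one and omits the serde derives and the `ActivatorGradient` impl the macro also emits.

```rust
/// the trait the generated unit structs implement
pub trait Activator<T> {
    type Output;

    fn activate(&self, x: T) -> Self::Output;
}

/// the zero-sized activator the macro would emit for the `ReLU.relu` entry
#[derive(Clone, Copy, Debug, Default)]
pub struct ReLU;

impl<T> Activator<T> for ReLU
where
    T: PartialOrd + Default, // stand-in for the crate's `ReLUActivation` bound
{
    type Output = T;

    fn activate(&self, x: T) -> T {
        if x > T::default() { x } else { T::default() }
    }
}

fn main() {
    assert_eq!(ReLU.activate(-3.0_f32), 0.0);
    assert_eq!(ReLU.activate(2.5_f32), 2.5);
}
```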
-pub trait LayerExt: Layer -where - D: Dimension, - S: RawData, -{ + /// update the layer parameters fn set_params(&mut self, params: ParamsBase) { *self.params_mut() = params; } /// backward propagate error through the layer - fn backward(&mut self, input: X, error: Y, gamma: Self::Elem) -> Option + fn backward(&mut self, input: X, error: Y, gamma: A) where S: Data, - Self: ActivatorGradient, - Self::Elem: Clone, - ParamsBase: Backward, + Self: ActivatorGradient, + A: Clone, + ParamsBase: Backward, { - // compute the delta using the activation function let delta = self.activate_gradient(error); - // apply the backward function of the inherited layer self.params_mut().backward(&input, &delta, gamma) } /// complete a forward pass through the layer - fn forward(&self, input: &X) -> Option + fn forward(&self, input: &X) -> Y where ParamsBase: Forward, Self: Activator, diff --git a/core/src/layers/types/aliases.rs b/core/src/layers/types/aliases.rs index ee670940..198297a2 100644 --- a/core/src/layers/types/aliases.rs +++ b/core/src/layers/types/aliases.rs @@ -3,20 +3,33 @@ authors: @FL03 */ #[cfg(feature = "alloc")] -use crate::layers::Activator; -use crate::layers::{LayerBase, Linear, ReLU, Sigmoid, Tanh}; -#[cfg(feature = "alloc")] -use alloc::boxed::Box; +pub use self::impl_alloc::*; + +use crate::layers::{HeavySide, Layer, Linear, ReLU, Sigmoid, Tanh}; +use concision_params::{Params, ParamsBase}; +use ndarray::Ix2; + +pub type LayerParamsBase = Layer>; + +pub type LayerParams = Layer>; + +/// A type alias for a [`Layer`] configured with a [`Linear`] activation function. +pub type LinearLayer = Layer; +/// A type alias for a [`Layer`] configured with a [`Sigmoid`] activation function. +pub type SigmoidLayer = Layer; +/// A type alias for a [`Layer`] configured with a [`Tanh`] activation function. +pub type TanhLayer = Layer; +/// A type alias for a [`Layer`] configured with a [`ReLU`] activation function. +pub type ReluLayer = Layer; +/// A type alias for a [`Layer`] configured with a [`HeavySide`] activation function. +/// This activation function is also known as the step function. +pub type HeavysideLayer = Layer; #[cfg(feature = "alloc")] -/// A type alias for a [`LayerBase`] configured with a dynamic [`Activator`]. -pub type LayerDyn = LayerBase + 'static>, S, D>; +mod impl_alloc { + use crate::layers::{Activator, Layer}; + use alloc::boxed::Box; -/// A type alias for a [`LayerBase`] configured with a [`Linear`] activation function. -pub type LinearLayer = LayerBase; -/// A type alias for a [`LayerBase`] configured with a [`Sigmoid`] activation function. -pub type SigmoidLayer = LayerBase; -/// A type alias for a [`LayerBase`] configured with a [`Tanh`] activation function. -pub type TanhLayer = LayerBase; -/// A type alias for a [`LayerBase`] configured with a [`ReLU`] activation function. -pub type ReluLayer = LayerBase; + /// A type alias for a [`Layer`] configured with a dynamic [`Activator`]. + pub type LayerDyn = Layer + 'static>, T>; +} diff --git a/core/src/layout.rs b/core/src/layout.rs index 63139e74..989317e2 100644 --- a/core/src/layout.rs +++ b/core/src/layout.rs @@ -4,14 +4,16 @@ */ mod impl_model_features; mod impl_model_format; +mod impl_model_layout; +/// A trait that consumes the caller to create a new instance of [`ModelFeatures`] object. pub trait IntoModelFeatures { fn into_model_features(self) -> ModelFeatures; } /// The [`RawModelLayout`] trait defines a minimal interface for objects capable of representing -/// the _layout_; i.e. 
the number of input, hidden, and output features of a neural network model -/// containing some number of hidden layers. +/// the _layout_; i.e. the number of input, hidden, and output features of a neural network +/// model containing some number of hidden layers. /// /// **Note**: This trait is implemented for the 3- and 4-tuple consisting of usize elements as /// well as for the `[usize; 3]` and `[usize; 4]` array types. In both these instances, the @@ -26,23 +28,7 @@ pub trait RawModelLayout { fn output(&self) -> usize; /// returns the number of hidden layers within the network fn layers(&self) -> usize; -} -pub trait ModelLayoutMut: RawModelLayout { - /// returns a mutable reference to number of the input features for the model - fn input_mut(&mut self) -> &mut usize; - /// returns a mutable reference to the number of hidden features for the model - fn hidden_mut(&mut self) -> &mut usize; - /// returns a mutable reference to the number of hidden layers for the model - fn layers_mut(&mut self) -> &mut usize; - /// returns a mutable reference to the output features for the model - fn output_mut(&mut self) -> &mut usize; -} - -/// The [`ModelLayout`] trait defines an interface for object capable of representing the -/// _layout_; i.e. the number of input, hidden, and output features of a neural network model -/// containing some number of hidden layers. -pub trait ModelLayout: RawModelLayout + ModelLayoutMut + Clone + core::fmt::Debug { /// the dimension of the input layer; (input, hidden) fn dim_input(&self) -> (usize, usize) { (self.input(), self.hidden()) } @@ -71,6 +57,19 @@ fn size_output(&self) -> usize { self.hidden() * self.output() } +} +/// The [`RawModelLayoutMut`] trait defines a mutable interface for objects capable of representing +/// the _layout_; i.e. the number of input, hidden, and output features of a neural network +/// model containing some number of hidden layers. +pub trait RawModelLayoutMut: RawModelLayout { + /// returns a mutable reference to the number of input features for the model + fn input_mut(&mut self) -> &mut usize; + /// returns a mutable reference to the number of hidden features for the model + fn hidden_mut(&mut self) -> &mut usize; + /// returns a mutable reference to the number of hidden layers for the model + fn layers_mut(&mut self) -> &mut usize; + /// returns a mutable reference to the number of output features for the model + fn output_mut(&mut self) -> &mut usize; + #[inline] /// update the number of input features for the model and return a mutable reference to the /// current layout. @@ -101,6 +100,24 @@ } } +/// The [`LayoutExt`] trait defines an interface for objects capable of representing the +/// _layout_; i.e. the number of input, hidden, and output features of a neural network model +/// containing some number of hidden layers. +pub trait LayoutExt: RawModelLayout + RawModelLayoutMut + Clone + core::fmt::Debug {} + +/// The [`NetworkDepth`] trait is used to define the depth/kind of a neural network model. +pub trait NetworkDepth { private!(); } + +type_tags! { + #[NetworkDepth] + pub enum { + Deep, + Shallow, + } +} + /// The [`ModelFormat`] type enumerates the various formats a neural network may take, either /// shallow or deep, providing a unified interface for accessing the number of hidden features /// and layers in the model.
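Concretely, the layout traits above reduce a shallow network's geometry to a 3-tuple `(input, hidden, output)`: the dim helpers pair adjacent entries and the size helpers are their products. A tiny sketch of that arithmetic, with no crate types involved:

```rust
// per-layer weight dimensions fall out of adjacent entries of the layout
fn dims(input: usize, hidden: usize, output: usize) -> [(usize, usize); 2] {
    [
        (input, hidden),  // input layer: dim_input
        (hidden, output), // output layer: dim_output
    ]
}

fn main() {
    let (input, hidden, output) = (4, 8, 2);
    let [din, dout] = dims(input, hidden, output);
    assert_eq!(din, (4, 8));
    assert_eq!(dout, (8, 2));
    // size_input / size_output from the trait are just the products
    assert_eq!(din.0 * din.1, 32);
    assert_eq!(dout.0 * dout.1, 16);
}
```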
This is done largely for simplicity, as it eliminates the need to @@ -133,6 +150,16 @@ pub struct ModelFeatures { /// the number of output features pub(crate) output: usize, } +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] +#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] +pub struct ModelLayout +where + D: NetworkDepth, + F: RawModelLayout, +{ + pub(crate) features: F, + pub(crate) _marker: core::marker::PhantomData, +} /* ************* Implementations ************* @@ -174,7 +201,7 @@ where } } -impl ModelLayout for T where T: ModelLayoutMut + Copy + core::fmt::Debug {} +impl LayoutExt for T where T: RawModelLayoutMut + Copy + core::fmt::Debug {} impl RawModelLayout for (usize, usize, usize) { fn input(&self) -> usize { @@ -207,7 +234,7 @@ impl RawModelLayout for (usize, usize, usize, usize) { } } -impl ModelLayoutMut for (usize, usize, usize, usize) { +impl RawModelLayoutMut for (usize, usize, usize, usize) { fn input_mut(&mut self) -> &mut usize { &mut self.0 } @@ -261,7 +288,7 @@ impl RawModelLayout for [usize; 4] { } } -impl ModelLayoutMut for [usize; 4] { +impl RawModelLayoutMut for [usize; 4] { fn input_mut(&mut self) -> &mut usize { &mut self[0] } @@ -275,3 +302,49 @@ impl ModelLayoutMut for [usize; 4] { &mut self[3] } } + +impl IntoModelFeatures for (usize, usize, usize) { + fn into_model_features(self) -> ModelFeatures { + ModelFeatures { + input: self.0, + inner: ModelFormat::Shallow { hidden: self.1 }, + output: self.2, + } + } +} + +impl IntoModelFeatures for (usize, usize, usize, usize) { + fn into_model_features(self) -> ModelFeatures { + ModelFeatures { + input: self.0, + inner: ModelFormat::Deep { + hidden: self.1, + layers: self.3, + }, + output: self.2, + } + } +} + +impl IntoModelFeatures for [usize; 3] { + fn into_model_features(self) -> ModelFeatures { + ModelFeatures { + input: self[0], + inner: ModelFormat::Shallow { hidden: self[1] }, + output: self[2], + } + } +} + +impl IntoModelFeatures for [usize; 4] { + fn into_model_features(self) -> ModelFeatures { + ModelFeatures { + input: self[0], + inner: ModelFormat::Deep { + hidden: self[1], + layers: self[3], + }, + output: self[2], + } + } +} diff --git a/core/src/layout/impl_model_features.rs b/core/src/layout/impl_model_features.rs index bf9db0dc..b089213f 100644 --- a/core/src/layout/impl_model_features.rs +++ b/core/src/layout/impl_model_features.rs @@ -2,7 +2,7 @@ Appellation: layout Contrib: @FL03 */ -use super::{ModelFeatures, ModelFormat, ModelLayoutMut, RawModelLayout}; +use super::{ModelFeatures, ModelFormat, RawModelLayout, RawModelLayoutMut}; /// verify if the input and hidden dimensions are compatible by checking: /// @@ -13,16 +13,12 @@ fn _verify_input_and_hidden_shape(input: D, hidden: D) -> bool where D: ndarray::Dimension, { - let mut valid = true; - // // check that the hidden dimension is square - // if hidden.ndim() > 1 && hidden.shape().iter().any(|&d| d != hidden.shape()[0]) { - // valid = false; - // } + let lhs = input.as_array_view(); + let rhs = hidden.as_array_view(); // check that the input and hidden dimensions are compatible - if input.ndim() != hidden.ndim() { - valid = false; - } - valid + input.ndim() != hidden.ndim() + && lhs[input.ndim() - 1] != rhs[0] + && rhs.iter().all(|&v| v == rhs[0]) } impl ModelFeatures { @@ -202,7 +198,7 @@ impl RawModelLayout for ModelFeatures { } } -impl ModelLayoutMut for ModelFeatures { +impl RawModelLayoutMut for ModelFeatures { fn input_mut(&mut self) -> &mut usize { self.input_mut() } diff --git 
diff --git a/core/src/layout/impl_model_layout.rs b/core/src/layout/impl_model_layout.rs
new file mode 100644
index 00000000..bee305f1
--- /dev/null
+++ b/core/src/layout/impl_model_layout.rs
@@ -0,0 +1,88 @@
+/*
+    Appellation: impl_model_layout
+    Created At: 2025.12.09:07:46:57
+    Contrib: @FL03
+*/
+use super::ModelLayout;
+
+use crate::layout::{Deep, NetworkDepth, RawModelLayout, Shallow};
+
+impl<D, F> ModelLayout<D, F>
+where
+    F: RawModelLayout,
+    D: NetworkDepth,
+{
+    /// creates a new instance of [`ModelLayout`] using the given features
+    pub const fn new(features: F) -> Self {
+        Self {
+            features,
+            _marker: core::marker::PhantomData::<D>,
+        }
+    }
+    /// returns a reference to the features of the model layout
+    pub const fn features(&self) -> &F {
+        &self.features
+    }
+    /// returns a mutable reference to the features of the model layout
+    pub const fn features_mut(&mut self) -> &mut F {
+        &mut self.features
+    }
+    /// returns the number of input features for the model layout
+    pub fn input(&self) -> usize {
+        self.features().input()
+    }
+    /// returns the number of output features for the model layout
+    pub fn output(&self) -> usize {
+        self.features().output()
+    }
+    /// returns the number of hidden features for the model layout
+    pub fn hidden(&self) -> usize {
+        self.features().hidden()
+    }
+    /// returns the depth, or number of hidden layers, of the model
+    pub fn layers(&self) -> usize {
+        self.features().layers()
+    }
+}
+
+impl<D, F> core::ops::Deref for ModelLayout<D, F>
+where
+    F: RawModelLayout,
+    D: NetworkDepth,
+{
+    type Target = F;
+
+    fn deref(&self) -> &Self::Target {
+        &self.features
+    }
+}
+
+impl<D, F> core::ops::DerefMut for ModelLayout<D, F>
+where
+    F: RawModelLayout,
+    D: NetworkDepth,
+{
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        &mut self.features
+    }
+}
+
+impl<F> ModelLayout<Deep, F>
+where
+    F: RawModelLayout,
+{
+    /// creates a new instance of [`ModelLayout`] using the given features and a deep depth
+    pub const fn deep(features: F) -> Self {
+        Self::new(features)
+    }
+}
+
+impl<F> ModelLayout<Shallow, F>
+where
+    F: RawModelLayout,
+{
+    /// returns a new instance of the model layout using the given features and a shallow depth
+    pub const fn shallow(features: F) -> Self {
+        Self::new(features)
+    }
+}
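A brief usage sketch for the new `ModelLayout` wrapper; the module path and the `(D, F)` type-parameter order are inferred from this diff rather than confirmed.

```rust
// Assumed import path; `Deep` and `Shallow` are the type tags from layout.rs.
use concision_core::layout::{Deep, ModelLayout, RawModelLayout, Shallow};

fn main() {
    // tag a 4-tuple of features as a deep layout
    let features: (usize, usize, usize, usize) = (3, 9, 1, 8);
    let deep = ModelLayout::<Deep, _>::deep(features);
    // the wrapper's own accessors delegate to the wrapped features...
    assert_eq!(deep.input(), 3);
    // ...and `Deref` exposes the features' trait methods directly
    assert_eq!(deep.size_output(), 9);

    // a shallow layout over a 3-tuple
    let feats3: (usize, usize, usize) = (3, 9, 1);
    let shallow = ModelLayout::<Shallow, _>::shallow(feats3);
    assert_eq!(shallow.features().output(), 1);
}
```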
diff --git a/core/src/lib.rs b/core/src/lib.rs
index cee6a451..3c2ec8eb 100644
--- a/core/src/lib.rs
+++ b/core/src/lib.rs
@@ -29,7 +29,9 @@
    clippy::missing_safety_doc,
    clippy::module_inception,
    clippy::needless_doctest_main,
-    clippy::upper_case_acronyms
+    clippy::should_implement_trait,
+    clippy::upper_case_acronyms,
+    rustdoc::redundant_explicit_links
)]
#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(feature = "nightly", feature(allocator_api))]
@@ -42,13 +44,6 @@ compiler_error! {
#[cfg(feature = "alloc")]
extern crate alloc;

-#[cfg(feature = "rand")]
-#[doc(no_inline)]
-pub use rand;
-#[cfg(feature = "rand")]
-#[doc(no_inline)]
-pub use rand_distr;
-
/// this module establishes generic random initialization routines for models, params, and
/// tensors.
#[doc(inline)]
@@ -65,10 +60,12 @@ pub use concision_traits::prelude::*;

#[macro_use]
pub(crate) mod macros {
+    #[macro_use]
+    pub mod config;
    #[macro_use]
    pub mod seal;
    #[macro_use]
-    pub mod config;
+    pub mod units;
}

pub mod activate;
@@ -80,14 +77,16 @@ pub mod models;
pub mod nn;
pub mod utils;

-pub mod types {
-    //! Core types supporting the `cnc` framework.
+#[doc(hidden)]
+pub mod ex {
+    pub mod sample;
}

// re-exports
#[doc(inline)]
pub use self::{
-    activate::prelude::*, config::prelude::*, error::*, layout::*, models::prelude::*,
-    utils::prelude::*,
+    activate::prelude::*, config::prelude::*, error::*, layers::Layer, layout::*,
+    models::prelude::*, utils::prelude::*,
};
// prelude
#[doc(hidden)]
@@ -97,6 +96,7 @@ pub mod prelude {
    pub use concision_traits::prelude::*;

    pub use crate::activate::prelude::*;
+    pub use crate::config::prelude::*;
    pub use crate::layers::prelude::*;
    pub use crate::layout::*;
    pub use crate::models::prelude::*;
diff --git a/core/src/macros/config.rs b/core/src/macros/config.rs
index 2e6250d1..fe3a053e 100644
--- a/core/src/macros/config.rs
+++ b/core/src/macros/config.rs
@@ -12,9 +12,7 @@ macro_rules! config {
    (
        $(#[$attr:meta])*
-        $vis:vis struct $name:ident {
-            $($field:ident: $type:ty),* $(,)?
-        }
+        $vis:vis struct $name:ident {$($field:ident: $type:ty),* $(,)?}
    ) => {
        $(#[$attr])*
        #[cfg_attr(
@@ -28,6 +26,9 @@ macro_rules! config {
            $($field: $type),*
        }

+        config!(@impl $vis struct $name {$($field: $type),*});
+    };
+    (@impl $vis:vis struct $name:ident {$($field:ident: $type:ty),* $(,)?}) => {
        impl $name {
            pub fn new() -> Self {
                Self {
@@ -45,7 +46,7 @@
                &mut self.$field
            }
            /// update the current value of the field and return a mutable reference to self
-            pub fn [<with_ $field>](&mut self, value: $type) -> &mut Self {
+            pub const fn [<with_ $field>](&mut self, value: $type) -> &mut Self {
                self.$field = value;
                self
            }
diff --git a/core/src/macros/units.rs b/core/src/macros/units.rs
new file mode 100644
index 00000000..2a8fc7ab
--- /dev/null
+++ b/core/src/macros/units.rs
@@ -0,0 +1,31 @@
+/*
+    Appellation: units
+    Created At: 2025.12.08:19:44:33
+    Contrib: @FL03
+*/
+
+macro_rules! type_tags {
+    (#[$tgt:ident] $vis:vis $s:ident {$($name:ident),* $(,)?}) => {
+        $(
+            type_tags!(@impl #[$tgt] $vis $s $name);
+        )*
+    };
+    (@impl #[$tgt:ident] $vis:vis enum $name:ident) => {
+        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
+        #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
+        $vis enum $name {}
+
+        impl $tgt for $name {
+            seal!();
+        }
+    };
+    (@impl #[$tgt:ident] $vis:vis struct $name:ident) => {
+        #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
+        #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
+        $vis struct $name;
+
+        impl $tgt for $name {
+            seal!();
+        }
+    };
+}
diff --git a/core/src/models/impls/impl_model_params_rand.rs b/core/src/models/impls/impl_model_params_rand.rs
index b76821a2..519ac380 100644
--- a/core/src/models/impls/impl_model_params_rand.rs
+++ b/core/src/models/impls/impl_model_params_rand.rs
@@ -2,11 +2,11 @@
    appellation: impl_model_params_rand
    authors: @FL03
*/
-
use crate::models::{DeepParamsBase, ShallowParamsBase};
use crate::ModelFeatures;

-use concision_init::{InitRand, distr as init};
+use concision_init::distr as init;
+use concision_init::{NdInit, rand_distr};
use concision_params::ParamsBase;
use ndarray::{DataOwned, Ix2};
use num_traits::{Float, FromPrimitive};
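For clarity, here is roughly what an invocation of the `config!` macro refactored above is expected to expand to (field setters that return `&mut Self`); `TrainingConfig` and its fields are hypothetical, and the derive attributes the macro attaches are elided.

```rust
// Hypothetical expansion sketch of `config!(pub struct TrainingConfig { ... })`.
pub struct TrainingConfig {
    epochs: usize,
    batch_size: usize,
}

impl TrainingConfig {
    pub fn new() -> Self {
        Self {
            epochs: Default::default(),
            batch_size: Default::default(),
        }
    }
    /// update the current value of the field and return a mutable reference to self
    pub const fn with_epochs(&mut self, value: usize) -> &mut Self {
        self.epochs = value;
        self
    }
    pub const fn with_batch_size(&mut self, value: usize) -> &mut Self {
        self.batch_size = value;
        self
    }
}

fn main() {
    let mut config = TrainingConfig::new();
    // the `with_*` setters chain, mirroring the builder style used in the tests
    config.with_epochs(1000).with_batch_size(32);
    assert_eq!((config.epochs, config.batch_size), (1000, 32));
}
```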
diff --git a/core/src/models/impls/impl_params_deep.rs b/core/src/models/impls/impl_params_deep.rs
index 3830841c..8740e642 100644
--- a/core/src/models/impls/impl_params_deep.rs
+++ b/core/src/models/impls/impl_params_deep.rs
@@ -102,19 +102,16 @@ where
    /// sequentially forwards the input through the model without any activations or other
    /// complexities in-between; not overly useful, but it is here for completeness
    #[inline]
-    pub fn forward<X, Y>(&self, input: &X) -> Option<Y>
+    pub fn forward<X, Y>(&self, input: &X) -> Y
    where
        A: Clone,
        S: Data,
        ParamsBase: Forward + Forward,
    {
-        // forward the input through the input layer
-        let mut output = self.input().forward(input)?;
-        // forward the input through each of the hidden layers
-        for layer in self.hidden() {
-            output = layer.forward(&output)?;
-        }
-        // finally, forward the output through the output layer
+        let mut output = self.input().forward(input);
+        self.hidden().into_iter().for_each(|layer| {
+            output = layer.forward(&output);
+        });
        self.output().forward(&output)
    }
}
diff --git a/core/src/models/impls/impl_params_shallow.rs b/core/src/models/impls/impl_params_shallow.rs
index f4edf9d9..21517583 100644
--- a/core/src/models/impls/impl_params_shallow.rs
+++ b/core/src/models/impls/impl_params_shallow.rs
@@ -82,19 +82,15 @@ where
        }
    }
    /// forward input through the controller network
-    pub fn forward(&self, input: &Array1<A>) -> Option<Array1<A>>
+    pub fn forward(&self, input: &Array1<A>) -> Array1<A>
    where
        A: Float + ScalarOperand,
        S: Data,
    {
-        // forward the input through the input layer; activate using relu
-        let mut output = self.input().forward(input)?.relu();
-        // forward the input through the hidden layer(s); activate using relu
-        output = self.hidden().forward(&output)?.relu();
-        // forward the input through the output layer; activate using sigmoid
-        output = self.output().forward(&output)?.sigmoid();
-
-        Some(output)
+        use concision_traits::Forward;
+        let mut output = self.input().forward_then(input, |x| x.relu());
+        output = self.hidden().forward_then(&output, |x| x.relu());
+        self.output().forward_then(&output, |x| x.sigmoid())
    }
}
diff --git a/core/src/models/mod.rs b/core/src/models/mod.rs
index c0f576dd..b3e6bde0 100644
--- a/core/src/models/mod.rs
+++ b/core/src/models/mod.rs
@@ -10,12 +10,6 @@ pub use self::{model_params::*, traits::*, types::*};

pub mod model_params;

-#[doc(hidden)]
-pub mod ex {
-    #[cfg(all(feature = "rand", feature = "std"))]
-    pub mod sample;
-}
-
mod impls {
    mod impl_model_params;
    mod impl_params_deep;
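The reworked forward passes above drop the `Option` plumbing and simply thread each layer's output into the next. A dependency-free sketch of that chaining pattern, with a toy `Forward` trait standing in for the `concision_traits` version:

```rust
// Minimal stand-in for the Forward trait; the real one lives in concision_traits.
pub trait Forward<X> {
    type Output;
    fn forward(&self, input: &X) -> Self::Output;
}

// a toy "layer" that just scales its input
struct Scale(f64);

impl Forward<f64> for Scale {
    type Output = f64;
    fn forward(&self, input: &f64) -> f64 {
        input * self.0
    }
}

// input layer first, each hidden layer consumes the previous output,
// and the output layer produces the final result
fn forward_all(input: f64, input_layer: &Scale, hidden: &[Scale], output_layer: &Scale) -> f64 {
    let mut out = input_layer.forward(&input);
    for layer in hidden {
        out = layer.forward(&out);
    }
    output_layer.forward(&out)
}

fn main() {
    let hidden = [Scale(2.0), Scale(2.0)];
    // 1.0 * 10 = 10, * 2 = 20, * 2 = 40, * 0.5 = 20
    assert_eq!(forward_all(1.0, &Scale(10.0), &hidden, &Scale(0.5)), 20.0);
}
```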
diff --git a/core/src/models/model_params.rs b/core/src/models/model_params.rs
index 88626a23..c3d9c197 100644
--- a/core/src/models/model_params.rs
+++ b/core/src/models/model_params.rs
@@ -11,16 +11,11 @@ use crate::{DeepModelRepr, RawHidden};
/// The [`ModelParamsBase`] object is a generic container for storing the parameters of a
/// neural network, regardless of the layout (e.g. shallow or deep). This is made possible
/// through the introduction of a generic hidden layer type, `H`, that allows us to define
-/// aliases and additional traits for contraining the hidden layer type. That being said, we
-/// don't reccoment using this type directly, but rather use the provided type aliases such as
-/// [`DeepModelParams`] or [`ShallowModelParams`] or their owned variants. These provide a much
-/// more straighforward interface for typing the parameters of a neural network. We aren't too
-/// worried about the transmutation between the two since users desiring this ability should
-/// simply stick with a _deep_ representation, initializing only a single layer within the
-/// respective container.
+/// aliases and additional traits for constraining the hidden layer type. Additionally, the
+/// structure enables the introduction of common accessors and initialization routines.
///
-/// This type also enables us to define a set of common initialization routines and introduce
-/// other standards for dealing with parameters in a neural network.
+/// With that in mind, we don't recommend using the implementation directly; rather, leverage
+/// a type alias that best suits your use case (e.g. owned parameters, arc parameters, etc.).
pub struct ModelParamsBase::Elem>
where
    D: Dimension,
diff --git a/core/src/models/traits/model.rs b/core/src/models/traits/model.rs
index ee238df6..7083cf2d 100644
--- a/core/src/models/traits/model.rs
+++ b/core/src/models/traits/model.rs
@@ -3,7 +3,7 @@
    authors: @FL03
*/
use crate::config::ModelConfiguration;
-use crate::{DeepModelParams, ModelLayout, RawModelLayout};
+use crate::{DeepModelParams, LayoutExt, RawModelLayout};

use concision_params::Params;
use concision_traits::Predict;
@@ -15,7 +15,7 @@ pub trait Model {
    /// The type of configuration used for the model
    type Config: ModelConfiguration;
    /// The type of [`ModelLayout`] used by the model for this implementation.
-    type Layout: ModelLayout;
+    type Layout: LayoutExt;
    /// returns an immutable reference to the models configuration; this is typically used to
    /// access the models hyperparameters (i.e. learning rate, momentum, etc.) and other
    /// related control parameters.
@@ -39,7 +39,7 @@
    /// By default, the trait simply passes each output from one layer to the next, however,
    /// custom models will likely override this method to inject activation methods and other
    /// related logic
-    fn predict<U, V>(&self, inputs: &U) -> Option<V>
+    fn predict<U, V>(&self, inputs: &U) -> V
    where
        Self: Predict<U, Output = V>,
    {
@@ -132,6 +132,6 @@ pub trait ModelExt: Model {

impl<M> ModelExt for M
where
    M: Model,
-    M::Layout: ModelLayout,
+    M::Layout: LayoutExt,
{
}
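Since `Model::predict` now returns the output directly instead of an `Option`, here is a simplified sketch of the pass-through behavior; the local `Predict` trait is a stand-in, as the real signature from `concision_traits` is not shown in this hunk.

```rust
// Stand-in for concision_traits::Predict; the actual bound shape is assumed.
pub trait Predict<U> {
    type Output;
    fn predict(&self, input: &U) -> Self::Output;
}

// a trivial "model" that doubles every input element
struct Doubler;

impl Predict<Vec<f64>> for Doubler {
    type Output = Vec<f64>;
    fn predict(&self, input: &Vec<f64>) -> Vec<f64> {
        input.iter().map(|x| x * 2.0).collect()
    }
}

fn main() {
    let model = Doubler;
    // the output comes back directly; no `Option` to unwrap
    assert_eq!(model.predict(&vec![1.0, 2.0]), vec![2.0, 4.0]);
}
```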
diff --git a/core/src/nn/mod.rs b/core/src/nn/mod.rs
index 407c79a1..1773c040 100644
--- a/core/src/nn/mod.rs
+++ b/core/src/nn/mod.rs
@@ -6,22 +6,47 @@
//! This module provides network specific implementations and traits supporting the development
//! of neural network models.
//!
-#[doc(inline)]
-pub use self::{neural_network::*, types::*};

-mod neural_network;
+pub(crate) mod prelude {
+    pub use super::{NetworkConsts, NeuralNetwork, NeuralNetworkParams};
+}

-mod traits {}
+use ndarray::{Dimension, RawData};
+
+pub trait NeuralNetworkParams<S, D, A = <S as RawData>::Elem>
+where
+    D: Dimension,
+    S: RawData,
+{
+}

-mod types {
-    #[doc(inline)]
-    pub use self::depth::*;
+/// The [`NeuralNetwork`] trait is used to define the network itself as well as each of its
+/// constituent parts.
+pub trait NeuralNetwork<S, D, A = <S as RawData>::Elem>
+where
+    D: Dimension,
+    S: RawData,
+{
+    /// The context of the neural network defines any additional information required for its operation.
+    type Ctx;
+    /// The configuration of the neural network defines its architecture and hyperparameters.
+    type Config;
+    /// The parameters of the neural network define its weights and biases.
+    type Params<_S, _D>: NeuralNetworkParams<_S, _D, A>
+    where
+        _S: RawData,
+        _D: Dimension;

-    mod depth;
+    /// returns a reference to the network configuration;
+    fn config(&self) -> &Self::Config;
+
+    fn params(&self) -> &Self::Params<S, D>;
+
+    fn params_mut(&mut self) -> &mut Self::Params<S, D>;
}

-#[allow(unused)]
-pub(crate) mod prelude {
-    pub use super::traits::*;
-    pub use super::types::*;
+/// A trait defining common constants for neural networks.
+pub trait NetworkConsts {
+    const NAME: &'static str;
+    const VERSION: &'static str;
}
diff --git a/core/src/nn/neural_network.rs b/core/src/nn/neural_network.rs
deleted file mode 100644
index 7f38d443..00000000
--- a/core/src/nn/neural_network.rs
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
-    Appellation: neural_network
-    Created At: 2025.11.28:15:01:28
-    Contrib: @FL03
-*/
-use super::{Deep, NetworkDepth, Shallow};
-use crate::config::ModelConfiguration;
-use ndarray::{Dimension, RawData};
-
-/// The [`NeuralNetwork`] trait defines a generic interface for neural network models.
-pub trait NeuralNetwork<S, D, A = <S as RawData>::Elem>
-where
-    D: Dimension,
-    S: RawData,
-{
-    type Config: ModelConfiguration;
-    type Depth: NetworkDepth;
-
-    /// returns a reference to the network configuration;
-    fn config(&self) -> &Self::Config;
-}
-
-pub trait DeepNeuralNetwork<S, D, A = <S as RawData>::Elem>:
-    NeuralNetwork<S, D, A>
-where
-    D: Dimension,
-    S: RawData,
-{
-    private!();
-}
-
-pub trait ShallowNeuralNetwork<S, D, A = <S as RawData>::Elem>:
-    NeuralNetwork<S, D, A>
-where
-    D: Dimension,
-    S: RawData,
-{
-    private!();
-}
-
-impl<N, S, D, A> DeepNeuralNetwork<S, D, A> for N
-where
-    D: Dimension,
-    S: RawData,
-    N: NeuralNetwork<S, D, A, Depth = Deep>,
-{
-    seal!();
-}
-
-impl<N, S, D, A> ShallowNeuralNetwork<S, D, A> for N
-where
-    D: Dimension,
-    S: RawData,
-    N: NeuralNetwork<S, D, A, Depth = Shallow>,
-{
-    seal!();
-}
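A minimal sketch of implementing the new `NetworkConsts` trait from `nn/mod.rs` above; the trait is redeclared locally so the snippet is self-contained, and the `Mlp` type is hypothetical.

```rust
// Local copy of the trait introduced above.
pub trait NetworkConsts {
    const NAME: &'static str;
    const VERSION: &'static str;
}

struct Mlp;

impl NetworkConsts for Mlp {
    const NAME: &'static str = "mlp";
    // pull the version straight from the crate's own metadata
    const VERSION: &'static str = env!("CARGO_PKG_VERSION");
}

fn main() {
    println!("{} v{}", Mlp::NAME, Mlp::VERSION);
}
```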
diff --git a/core/src/nn/types/depth.rs b/core/src/nn/types/depth.rs
deleted file mode 100644
index 3bba63d6..00000000
--- a/core/src/nn/types/depth.rs
+++ /dev/null
@@ -1,35 +0,0 @@
-/*
-    Appellation: depth
-    Created At: 2025.11.28:15:03:02
-    Contrib: @FL03
-*/
-
-/// The [`NetworkDepth`] trait is used to define the depth/kind of a neural network model.
-pub trait NetworkDepth {
-    private!();
-}
-
-macro_rules! network_format {
-    (#[$tgt:ident] $vis:vis enum {$($name:ident),* $(,)?}) => {
-        $(
-            network_format!(@impl #[$tgt] $vis $name);
-        )*
-    };
-    (@impl #[$tgt:ident] $vis:vis $name:ident) => {
-        #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
-        #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
-        $vis enum $name {}
-
-        impl $tgt for $name {
-            seal!();
-        }
-    };
-}
-
-network_format! {
-    #[NetworkDepth]
-    pub enum {
-        Deep,
-        Shallow,
-    }
-}
diff --git a/core/src/utils.rs b/core/src/utils.rs
index 5bb0b737..3dfe5762 100644
--- a/core/src/utils.rs
+++ b/core/src/utils.rs
@@ -7,19 +7,13 @@
#[doc(inline)]
pub use self::prelude::*;

-#[cfg(feature = "signal")]
-pub use self::fft::prelude::*;
-
-#[cfg(feature = "signal")]
-pub mod fft;
-
-pub(crate) mod arith;
-pub(crate) mod dropout;
-pub(crate) mod gradient;
-pub(crate) mod norm;
-pub(crate) mod pad;
-pub(crate) mod patterns;
-pub(crate) mod tensor;
+mod arith;
+mod dropout;
+mod gradient;
+mod norm;
+pub mod pad;
+mod patterns;
+mod tensor;

pub(crate) mod prelude {
    pub use super::arith::*;
@@ -29,7 +23,4 @@ pub(crate) mod prelude {
    pub use super::pad::*;
    pub use super::patterns::*;
    pub use super::tensor::*;
-
-    #[cfg(feature = "signal")]
-    pub use super::fft::prelude::*;
}
diff --git a/core/src/utils/dropout.rs b/core/src/utils/dropout.rs
index 61d53f24..99240852 100644
--- a/core/src/utils/dropout.rs
+++ b/core/src/utils/dropout.rs
@@ -50,7 +50,7 @@ impl Default for Dropout {
#[cfg(feature = "rand")]
mod impl_rand {
    use super::*;
-    use concision_init::InitRand;
+    use concision_init::NdInit;
    use concision_traits::Forward;
    use ndarray::{Array, ArrayBase, DataOwned, Dimension, ScalarOperand};
    use num_traits::Num;
@@ -80,8 +80,8 @@ mod impl_rand {
    {
        type Output = ::Output;

-        fn forward(&self, input: &U) -> Option<Self::Output> {
-            Some(input.dropout(self.p))
+        fn forward(&self, input: &U) -> Self::Output {
+            input.dropout(self.p)
        }
    }
}
diff --git a/core/src/utils/fft/mod.rs b/core/src/utils/fft/mod.rs
deleted file mode 100644
index 2f2f236e..00000000
--- a/core/src/utils/fft/mod.rs
+++ /dev/null
@@ -1,52 +0,0 @@
-/*
-    Appellation: fft
-    Created At: 2025.11.26:13:22:07
-    Contrib: @FL03
-*/
-//! this module implements the custom fast-fourier transform (FFT) algorithm
-#![cfg(feature = "complex")]
-
-#[doc(inline)]
-pub use self::{types::prelude::*, utils::*};
-
-/// this module implements the methods for the fast-fourier transform (FFT) module
-pub mod utils;
-
-pub mod types {
-    #[doc(inline)]
-    pub use self::prelude::*;
-
-    pub mod mode;
-    pub mod plan;
-
-    pub(crate) mod prelude {
-        pub use super::mode::*;
-        pub use super::plan::*;
-    }
-}
-
-pub(crate) mod prelude {
-    pub use super::utils::*;
-}
-
-/// The [`DFT`] trait establishes a common interface for discrete Fourier transform
-/// implementations.
-pub trait DFT {
-    type Output;
-
-    fn dft(&self) -> Self::Output;
-}
-
-#[cfg(test)]
-mod tests {
-    use super::FftPlan;
-    use super::utils::fft_permutation;
-
-    #[test]
-    fn test_plan() {
-        let samples = 16;
-
-        let plan = FftPlan::new(samples).build();
-        assert_eq!(plan.plan(), fft_permutation(16).as_slice());
-    }
-}
diff --git a/core/src/utils/fft/types/mode.rs b/core/src/utils/fft/types/mode.rs
deleted file mode 100644
index 89c97059..00000000
--- a/core/src/utils/fft/types/mode.rs
+++ /dev/null
@@ -1,122 +0,0 @@
-/*
-    Appellation: mode
-    Contrib: FL03
-*/
-
-pub trait RawField {
-    private!();
-}
-
-macro_rules! impl_raw_private {
-
-    (impl$(<$($T:ident),*>)? $trait:ident for $name:ident$(<$($V:ident),*>)? $(where $($rest:tt)*)?) => {
-        impl$(<$($T),*>)? $trait for $name$(<$($V),*>)? $(where $($rest)*)? {
-            seal!();
-        }
-    };
-}
-
-macro_rules!
toggle { - (@impl $vis:vis enum $name:ident) => { - #[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] - #[cfg_attr( - feature = "serde", - derive(serde_derive::Deserialize, serde_derive::Serialize), - serde(rename_all = "lowercase", untagged) - )] - pub enum $name {} - }; - ($vis:vis enum { $($name:ident),* $(,)? }) => { - $( - toggle!(@impl $vis enum $name); - impl_raw_private!(impl RawField for $name); - )* - }; -} - -toggle! { - pub enum { - C, - R - } -} - -/// -#[derive( - Clone, - Copy, - Debug, - Default, - Eq, - Hash, - Ord, - PartialEq, - PartialOrd, - strum::AsRefStr, - strum::Display, - strum::EnumCount, - strum::EnumIs, - strum::EnumIter, - strum::EnumString, - strum::VariantArray, - strum::VariantNames, - variants::VariantConstructors, -)] -#[cfg_attr( - feature = "serde", - derive(serde_derive::Deserialize, serde_derive::Serialize), - serde(rename_all = "lowercase", untagged) -)] -#[strum(serialize_all = "lowercase")] -pub enum FftMode { - #[default] - Complex, - Real, -} - -/// -#[derive( - Clone, - Copy, - Debug, - Default, - Eq, - Hash, - Ord, - PartialEq, - PartialOrd, - strum::AsRefStr, - strum::Display, - strum::EnumCount, - strum::EnumIs, - strum::EnumIter, - strum::EnumString, - strum::VariantArray, - strum::VariantNames, - variants::VariantConstructors, -)] -#[cfg_attr( - feature = "serde", - derive(serde_derive::Deserialize, serde_derive::Serialize), - serde(rename_all = "lowercase", untagged) -)] -#[strum(serialize_all = "lowercase")] -pub enum FftDirection { - #[default] - Forward = 0, - Inverse = 1, -} - -impl From for FftDirection { - fn from(direction: usize) -> Self { - match direction % 2 { - 0 => Self::Forward, - _ => Self::Inverse, - } - } -} -impl From for usize { - fn from(direction: FftDirection) -> Self { - direction as usize - } -} diff --git a/core/src/utils/fft/types/plan.rs b/core/src/utils/fft/types/plan.rs deleted file mode 100644 index 321a9222..00000000 --- a/core/src/utils/fft/types/plan.rs +++ /dev/null @@ -1,108 +0,0 @@ -/* - Appellation: plan - Contrib: FL03 -*/ -use crate::utils::fft::fft_permutation; - -#[cfg(feature = "alloc")] -use alloc::vec::{self, Vec}; - -#[derive(Clone, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)] -#[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))] -pub struct FftPlan { - len: usize, - plan: Vec, -} - -impl FftPlan { - pub fn new(len: usize) -> Self { - Self { - len, - plan: Vec::with_capacity(len), - } - } - - pub fn build(self) -> Self { - let plan = fft_permutation(self.len); - Self { plan, ..self } - } - - pub fn clear(&mut self) { - self.len = 0; - self.plan.clear(); - } - - pub fn get(&self, index: usize) -> Option<&usize> { - self.plan().get(index) - } - - pub fn iter<'a>(&'a self) -> core::slice::Iter<'a, usize> { - self.plan().iter() - } - - pub fn len(&self) -> usize { - self.len - } - - pub fn plan(&self) -> &[usize] { - &self.plan - } - - pub fn set(&mut self, len: usize) { - self.len = len; - self.plan = Vec::with_capacity(len); - } - - pub fn with(self, len: usize) -> Self { - Self { - len, - plan: Vec::with_capacity(len), - } - } -} - -impl AsRef<[usize]> for FftPlan { - fn as_ref(&self) -> &[usize] { - &self.plan - } -} - -impl AsMut<[usize]> for FftPlan { - fn as_mut(&mut self) -> &mut [usize] { - &mut self.plan - } -} - -impl Extend for FftPlan { - fn extend>(&mut self, iter: T) { - self.plan.extend(iter); - } -} - -impl FromIterator for FftPlan { - fn from_iter>(iter: T) -> Self { - let plan = Vec::from_iter(iter); - Self { - len: plan.len(), - 
plan, - } - } -} - -impl IntoIterator for FftPlan { - type Item = usize; - type IntoIter = vec::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.plan.into_iter() - } -} - -impl<'a> IntoIterator for &'a mut FftPlan { - type Item = &'a mut usize; - type IntoIter = core::slice::IterMut<'a, usize>; - - fn into_iter(self) -> Self::IntoIter { - self.plan.iter_mut() - } -} diff --git a/core/src/utils/fft/utils.rs b/core/src/utils/fft/utils.rs deleted file mode 100644 index 39616ee5..00000000 --- a/core/src/utils/fft/utils.rs +++ /dev/null @@ -1,198 +0,0 @@ -/* - Appellation: utils - Contrib: FL03 -*/ -use super::FftPlan; -use concision_traits::AsComplex; -use num_complex::{Complex, ComplexFloat}; -use num_traits::{Float, FloatConst, NumAssignOps, NumCast, NumOps}; - -pub(crate) fn fft_angle(n: usize) -> T -where - T: FloatConst + NumCast + NumOps, -{ - T::TAU() / T::from(n).unwrap() -} - -/// Computes the Fast Fourier Transform of a one-dimensional, complex-valued signal. -pub fn fft(input: impl AsRef<[S]>, permute: &FftPlan) -> Vec> -where - S: ComplexFloat, - S::Real: Float + FloatConst, - Complex: ComplexFloat + NumOps + NumOps, -{ - // - let input = input.as_ref(); - // - let n = input.len(); - // initialize the result vector - let mut result = Vec::with_capacity(n); - // store the input values in the result vector according to the permutation - for position in permute.clone().into_iter() { - let arg = input[position]; - result.push(Complex::new(arg.re(), arg.im())); - } - let mut segment: usize = 1; - while segment < n { - segment <<= 1; - // compute the angle of the complex number - let angle = fft_angle::(segment); - // compute the radius of the complex number (length) - let radius = Complex::new(angle.cos(), angle.sin()); - // iterate over the signal in segments of length `segment` - for start in (0..n).step_by(segment) { - let mut w = Complex::new(T::one(), T::zero()); - for position in start..(start + segment / 2) { - let a = result[position]; - let b = result[position + segment / 2] * w; - result[position] = a + b; - result[position + segment / 2] = a - b; - w = w * radius; - } - } - } - result -} - -/// Computes the Fast Fourier Transform of an one-dimensional, real-valued signal. -/// TODO: Optimize the function to avoid unnecessary computation. 
-pub fn rfft(input: impl AsRef<[T]>, input_permutation: impl AsRef<[usize]>) -> Vec> -where - T: Float + FloatConst, - Complex: ComplexFloat + NumAssignOps, -{ - // create a reference to the input - let input = input.as_ref(); - // fetch the length of the input - let n = input.len(); - // compute the size of the result vector - let size = (n - (n % 2)) / 2 + 1; - // initialize the output vector - let mut store = Vec::with_capacity(size); - // store the input values in the result vector according to the permutation - for position in input_permutation.as_ref() { - store.push(input[*position].as_re()); - } - let mut segment: usize = 1; - while segment < n { - segment <<= 1; - // compute the angle of the complex number - let angle = fft_angle::(segment); - // compute the radius of the complex number (length) - let radius = Complex::new(angle.cos(), angle.sin()); - for start in (0..n).step_by(segment) { - let mut w = Complex::new(T::one(), T::zero()); - for position in start..(start + segment / 2) { - let a = store[position]; - let b = store[position + segment / 2] * w; - store[position] = a + b; - store[position + segment / 2] = a - b; - w *= radius; - } - } - } - store - .iter() - .cloned() - .filter(|x| x.im() >= T::zero()) - .collect() -} -/// Computes the Inverse Fast Fourier Transform of an one-dimensional, complex-valued signal. -pub fn ifft(input: &[S], input_permutation: &FftPlan) -> Vec> -where - S: ComplexFloat, - T: Float + FloatConst, - Complex: ComplexFloat + NumOps + NumOps, -{ - let n = input.len(); - let mut result = Vec::with_capacity(n); - for position in input_permutation.clone().into_iter() { - let arg = input[position]; - result.push(Complex::new(arg.re(), arg.im())); - } - let mut length: usize = 1; - while length < n { - length <<= 1; - let angle = fft_angle::(length).neg(); - let radius = Complex::new(T::cos(angle), T::sin(angle)); // w_len - for start in (0..n).step_by(length) { - let mut w = Complex::new(T::one(), T::zero()); - for position in start..(start + length / 2) { - let a = result[position]; - let b = result[position + length / 2] * w; - result[position] = a + b; - result[position + length / 2] = a - b; - w = w * radius; - } - } - } - let scale = T::from(n).unwrap().recip(); - result.iter().map(|x| *x * scale).collect() -} -/// Computes the Inverse Fast Fourier Transform of an one-dimensional, real-valued signal. 
-/// TODO: Fix the function; currently fails to compute the correct result -pub fn irfft(input: &[Complex], plan: &FftPlan) -> Vec -where - T: Float + FloatConst, - Complex: ComplexFloat + NumAssignOps, -{ - let n = input.len(); - let mut result = vec![Complex::new(T::zero(), T::zero()); n]; - - for position in plan.clone().into_iter() { - result.push(input[position]); - } - // for res in result.clone() { - // if res.im() > T::zero() { - // result.push(res.conj()); - // } - // } - // segment length - let mut segment: usize = 1; - while segment < n { - segment <<= 1; - // compute the angle of the complex number - let angle = fft_angle::(segment).neg(); - // compute the radius of the complex number (length) - let radius = Complex::new(T::cos(angle), T::sin(angle)); - for start in (0..n).step_by(segment) { - let mut w = Complex::new(T::one(), T::zero()); - for position in start..(start + segment / 2) { - let a = result[position]; - let b = result[position + segment / 2] * w; - result[position] = a + b; - result[position + segment / 2] = a - b; - w *= radius; - } - } - } - let scale = T::from(n).unwrap().recip(); - result.iter().map(|x| x.re() * scale).collect() -} - -#[doc(hidden)] -/// Generates a permutation for the Fast Fourier Transform. -pub(crate) fn fft_permutation(length: usize) -> Vec { - let mut result = Vec::new(); - result.reserve_exact(length); - for i in 0..length { - result.push(i); - } - let mut reverse = 0_usize; - let mut position = 1_usize; - while position < length { - let mut bit = length >> 1; - while bit & reverse != 0 { - reverse ^= bit; - bit >>= 1; - } - reverse ^= bit; - // This is equivalent to adding 1 to a reversed number - if position < reverse { - // Only swap each element once - result.swap(position, reverse); - } - position += 1; - } - result -} diff --git a/core/src/utils/tensor.rs b/core/src/utils/tensor.rs index e2730d04..7fe15137 100644 --- a/core/src/utils/tensor.rs +++ b/core/src/utils/tensor.rs @@ -2,23 +2,26 @@ Appellation: tensor Contrib: FL03 */ -pub use self::{generators::*, stack::*}; -use ndarray::*; +#[cfg(feature = "alloc")] +pub use self::impl_alloc::*; +use ndarray::ScalarOperand; +use ndarray::prelude::*; use num_traits::{NumAssign, Zero}; -#[cfg(feature = "alloc")] -/// Creates an n-dimensional array from an iterator of n dimensional arrays. -pub fn concat_iter(axis: usize, iter: impl IntoIterator>) -> Array +pub fn genspace(features: usize) -> Array1 { + Array1::from_iter((0..features).map(|x| T::from(x).unwrap())) +} + +pub fn linarr(dim: impl Clone + IntoDimension) -> Result, ShapeError> where - D: RemoveAxis, - T: Clone, + A: Float, + D: Dimension, { - let mut arr = iter.into_iter().collect::>(); - let mut out = arr.pop().unwrap(); - for i in arr { - out = concatenate!(Axis(axis), out, i); - } - out + let dim = dim.into_dimension(); + let n = dim.size(); + Array::linspace(A::zero(), A::from(n - 1).unwrap(), n) + .to_shape(dim) + .map(|x| x.to_owned()) } pub fn inverse(matrix: &Array2) -> Option> @@ -65,6 +68,22 @@ where Some(inverted.to_owned()) } +/// Creates a larger array from an iterator of smaller arrays. 
+pub fn stack_iter(iter: impl IntoIterator>) -> Array2 +where + T: Clone + Zero, +{ + let mut iter = iter.into_iter(); + let first = iter.next().unwrap(); + let shape = [iter.size_hint().0 + 1, first.len()]; + let mut res = Array2::::zeros(shape); + res.slice_mut(s![0, ..]).assign(&first); + for (i, s) in iter.enumerate() { + res.slice_mut(s![i + 1, ..]).assign(&s); + } + res +} + /// Returns the lower triangular portion of a matrix. pub fn tril(a: &Array2) -> Array2 where @@ -92,52 +111,36 @@ where out } -pub(crate) mod generators { - use ndarray::{Array, Array1, Dimension, IntoDimension, ShapeError}; - use num_traits::{Float, NumCast}; - - pub fn genspace(features: usize) -> Array1 { - Array1::from_iter((0..features).map(|x| T::from(x).unwrap())) - } +use ndarray::{Array, Array1, Dimension, IntoDimension, ShapeError}; +use num_traits::{Float, NumCast}; - pub fn linarr(dim: impl Clone + IntoDimension) -> Result, ShapeError> - where - A: Float, - D: Dimension, - { - let dim = dim.into_dimension(); - let n = dim.size(); - Array::linspace(A::zero(), A::from(n - 1).unwrap(), n) - .to_shape(dim) - .map(|x| x.to_owned()) - } -} - -pub(crate) mod stack { - #[cfg(feature = "alloc")] +#[cfg(feature = "alloc")] +pub(crate) mod impl_alloc { use alloc::vec::Vec; - use ndarray::{Array1, Array2, s}; - use num_traits::Num; - /// Creates a larger array from an iterator of smaller arrays. - pub fn stack_iter(iter: impl IntoIterator>) -> Array2 + use ndarray::{Array, Array1, Array2, Axis, RemoveAxis, concatenate, s}; + use num_traits::Zero; + + /// Creates an n-dimensional array from an iterator of n dimensional arrays. + pub fn concat_iter( + axis: usize, + iter: impl IntoIterator>, + ) -> Array where - T: Clone + Num, + D: RemoveAxis, + T: Clone, { - let mut iter = iter.into_iter(); - let first = iter.next().unwrap(); - let shape = [iter.size_hint().0 + 1, first.len()]; - let mut res = Array2::::zeros(shape); - res.slice_mut(s![0, ..]).assign(&first); - for (i, s) in iter.enumerate() { - res.slice_mut(s![i + 1, ..]).assign(&s); + let mut arr = iter.into_iter().collect::>(); + let mut out = arr.pop().unwrap(); + for i in arr { + out = concatenate!(Axis(axis), out, i); } - res + out } - #[cfg(feature = "alloc")] + /// stack a 1D array into a 2D array by stacking them horizontally. pub fn hstack(iter: impl IntoIterator>) -> Array2 where - T: Clone + Num, + T: Clone + Zero, { let iter = Vec::from_iter(iter); let mut res = Array2::::zeros((iter.first().unwrap().len(), iter.len())); @@ -146,11 +149,10 @@ pub(crate) mod stack { } res } - #[cfg(feature = "alloc")] /// stack a 1D array into a 2D array by stacking them vertically. pub fn vstack(iter: impl IntoIterator>) -> Array2 where - T: Clone + Num, + T: Clone + Zero, { let iter = Vec::from_iter(iter); let mut res = Array2::::zeros((iter.len(), iter.first().unwrap().len())); diff --git a/core/tests/fft.rs b/core/tests/fft.rs deleted file mode 100644 index 3a38521b..00000000 --- a/core/tests/fft.rs +++ /dev/null @@ -1,125 +0,0 @@ -/* - Appellation: fft - Contrib: FL03 -*/ -use concision_core::utils::fft::*; - -use approx::assert_abs_diff_eq; -use lazy_static::lazy_static; -use num_complex::{Complex, ComplexFloat}; -use num_traits::Float; - -const EPSILON: f64 = 1e-6; - -lazy_static! 
{ - static ref EXPECTED_RFFT: Vec> = vec![ - Complex { re: 28.0, im: 0.0 }, - Complex { re: -4.0, im: 0.0 }, - Complex { - re: -4.0, - im: 1.6568542494923806 - }, - Complex { - re: -4.0, - im: 4.000000000000001 - }, - Complex { - re: -3.999999999999999, - im: 9.656854249492381 - } - ]; -} - -fn handle(data: Vec) -> Vec -where - S: ComplexFloat, - T: Copy + Float, -{ - let tmp = { - let mut inner = data - .iter() - .cloned() - .filter(|i| i.im() > T::zero()) - .map(|i| i.conj()) - .collect::>(); - inner.sort_by(|a, b| a.im().partial_cmp(&b.im()).unwrap()); - inner - }; - let mut out = data.clone(); - out.sort_by(|a, b| a.re().partial_cmp(&b.re()).unwrap()); - out.sort_by(|a, b| a.im().partial_cmp(&b.im()).unwrap()); - out.extend(tmp); - out -} - -#[test] -#[ignore = "Needs to be fixed"] -fn test_rfft() { - let polynomial = (0..8).map(|i| i as f64).collect::>(); - let plan = FftPlan::new(polynomial.len()).build(); - println!("Function Values: {:?}\nPlan: {:?}", &polynomial, &plan); - let fft = rfft(&polynomial, &plan); - let res = handle(fft.clone()); - assert!(fft.len() == EXPECTED_RFFT.len()); - for (x, y) in fft.iter().zip(EXPECTED_RFFT.iter()) { - assert_abs_diff_eq!(x.re(), y.re()); - assert_abs_diff_eq!(x.im(), y.im()); - } - let plan = FftPlan::new(fft.len()).build(); - let _ifft = dbg!(irfft(&res, &plan)); - // for (x, y) in ifft.iter().zip(polynomial.iter()) { - // assert_abs_diff_eq!(*x, *y, epsilon = EPSILON); - // } -} - -#[test] -fn small_polynomial_returns_self() { - let polynomial = vec![1.0f64, 1.0, 0.0, 2.5]; - let permutation = FftPlan::new(polynomial.len()).build(); - let fft = fft(&polynomial, &permutation); - let ifft = ifft(&fft, &permutation) - .into_iter() - .map(|i| i.re()) - .collect::>(); - for (x, y) in ifft.iter().zip(polynomial.iter()) { - assert_abs_diff_eq!(*x, *y, epsilon = EPSILON); - } -} - -#[test] -fn square_small_polynomial() { - let mut polynomial = vec![1.0f64, 1.0, 0.0, 2.0]; - polynomial.append(&mut vec![0.0; 4]); - let plan = FftPlan::new(polynomial.len()).build(); - let mut fft = fft(&polynomial, &plan); - fft.iter_mut().for_each(|num| *num *= *num); - let ifft = ifft(&fft, &plan) - .into_iter() - .map(|i| i.re()) - .collect::>(); - let expected = [1.0, 2.0, 1.0, 4.0, 4.0, 0.0, 4.0, 0.0, 0.0]; - for (x, y) in ifft.iter().zip(expected.iter()) { - assert_abs_diff_eq!(*x, *y, epsilon = EPSILON); - } -} - -#[test] -#[ignore] -fn square_big_polynomial() { - // This test case takes ~1050ms on my machine in unoptimized mode, - // but it takes ~70ms in release mode. 
- let n = 1 << 17; // ~100_000 - let mut polynomial = vec![1.0f64; n]; - polynomial.append(&mut vec![0.0f64; n]); - let permutation = FftPlan::new(polynomial.len()).build(); - let mut fft = fft(&polynomial, &permutation); - fft.iter_mut().for_each(|num| *num *= *num); - let ifft = irfft(&fft, &permutation) - .into_iter() - .map(|i| i.re()) - .collect::>(); - let expected = (0..((n << 1) - 1)).map(|i| std::cmp::min(i + 1, (n << 1) - 1 - i) as f64); - for (&x, y) in ifft.iter().zip(expected) { - assert_abs_diff_eq!(x, y); - } -} diff --git a/concision/tests/simple.rs b/core/tests/models.rs similarity index 55% rename from concision/tests/simple.rs rename to core/tests/models.rs index c088d3c3..de132be0 100644 --- a/concision/tests/simple.rs +++ b/core/tests/models.rs @@ -1,13 +1,16 @@ /* - appellation: simple - authors: @FL03 + Appellation: models + Created At: 2025.12.07:11:02:49 + Contrib: @FL03 */ -use cnc::models::ex::sample::TestModel; -use cnc::{Model, ModelFeatures, StandardModelConfig}; +#![allow(unused_mut)] + +use concision_core::ex::sample::TestModel; +use concision_core::{Model, ModelFeatures, StandardModelConfig}; use ndarray::prelude::*; #[test] -fn test_simple_model() -> anyhow::Result<()> { +fn test_simple_model() { let mut config = StandardModelConfig::new() .with_epochs(1000) .with_batch_size(32); @@ -17,17 +20,20 @@ fn test_simple_model() -> anyhow::Result<()> { // define the model features let features = ModelFeatures::deep(3, 9, 1, 8); // initialize the model with the given features and configuration - let model = TestModel::::new(config, features); + let mut model = TestModel::::new(config, features); + #[cfg(feature = "rand")] + { + model = model.init(); + } // initialize some input data let input = Array1::linspace(1.0, 9.0, model.layout().input()); - let expected = Array1::from_elem(model.layout().output(), 0.5); - // forward the input through the model - let output = model.predict(&input).expect("prediction failed"); + let output = model.predict(&input); // verify the output shape assert_eq!(output.dim(), (features.output())); // compare the results to what we expected - assert_eq!(output, expected); - - Ok(()) + #[cfg(not(feature = "rand"))] + { + assert_eq!(output, Array1::from_elem(model.layout().output(), 0.5)); + } } diff --git a/core/tests/params.rs b/core/tests/params.rs deleted file mode 100644 index fc591f97..00000000 --- a/core/tests/params.rs +++ /dev/null @@ -1,20 +0,0 @@ -/* - Appellation: params - Contrib: @FL03 -*/ -use concision_params::Params; - -use approx::assert_abs_diff_eq; -use ndarray::prelude::*; - -#[test] -fn test_params_forward() { - let params = Params::::ones((3, 4)); - let input = array![1.0, 2.0, 3.0]; - // should be of shape 4: (d_in, d_out).t() * (d_in,) + (d_out,) - let output = params.forward(&input).expect("forward-pass failed"); - assert_eq!(output.dim(), 4); - // output should be: $W.t() * x + b = [7.0, 7.0, 7.0, 7.0]$ - // where W = ones(3, 4) and b = ones(4) - assert_abs_diff_eq!(output, array![7.0, 7.0, 7.0, 7.0], epsilon = 1e-3); -} diff --git a/data/Cargo.toml b/data/Cargo.toml index 6fb68f9c..fd9325e6 100644 --- a/data/Cargo.toml +++ b/data/Cargo.toml @@ -1,26 +1,34 @@ [package] -build = "build.rs" -description = "this crate provides additional tools for working with datasets" -name = "concision-data" authors.workspace = true +build = "build.rs" categories.workspace = true +description = "this crate provides additional tools for working with datasets" edition.workspace = true homepage.workspace = true keywords.workspace = 
true license.workspace = true +name = "concision-data" readme.workspace = true repository.workspace = true rust-version.workspace = true version.workspace = true +[package.metadata.docs.rs] +all-features = false +features = ["full"] +rustc-args = ["--cfg", "docsrs"] + +[package.metadata.release] +no-dev-version = true +tag-name = "v{{version}}" + [lib] -crate-type = ["cdylib","rlib"] bench = false +crate-type = ["cdylib", "rlib"] doc = true doctest = true test = true -# ************* [Unit Tests] ************* [[test]] name = "default" @@ -38,12 +46,12 @@ thiserror = { workspace = true } approx = { optional = true, workspace = true } ndarray = { workspace = true } num = { workspace = true } -num-traits = { workspace = true } num-complex = { optional = true, workspace = true } +num-traits = { workspace = true } # concurrency & parallelism rayon = { optional = true, workspace = true } # networking -reqwest = { optional = true, workspace = true } +reqwest = { optional = true, version = "0.12" } # data & serialization serde = { optional = true, workspace = true } serde_json = { optional = true, workspace = true } @@ -52,121 +60,109 @@ tracing = { optional = true, workspace = true } [features] default = [ - "std", + "std", ] full = [ - "default", - "approx", - "complex", - "json", - "rand", - "serde", - "tracing", + "approx", + "complex", + "default", + "json", + "rand", + "serde", + "tracing", ] nightly = [ - "concision-core/nightly", + "concision-core/nightly", ] # ************* [FF:Features] ************* loader = [ - "json", - "reqwest", + "json", + "reqwest", ] # ************* [FF:Environments] ************* std = [ - "alloc", - "concision-core/std", - "ndarray/std", - "num/std", - "num-complex?/std", - "serde?/std", - "serde_json?/std", - "tracing?/std", - "variants/std", + "alloc", + "concision-core/std", + "ndarray/std", + "num-complex?/std", + "num/std", + "serde?/std", + "serde_json?/std", + "tracing?/std", + "variants/std", ] wasi = [ - "concision-core/wasi", + "concision-core/wasi", ] wasm = [ - "concision-core/wasm", + "concision-core/wasm", ] # ************* [FF:Dependencies] ************* alloc = [ - "concision-core/alloc", - "num/alloc", - "serde?/alloc", - "variants/alloc", + "concision-core/alloc", + "num/alloc", + "serde?/alloc", + "variants/alloc", ] approx = [ - "concision-core/approx", - "dep:approx", - "ndarray/approx", + "concision-core/approx", + "dep:approx", + "ndarray/approx", ] blas = [ - "concision-core/blas", - "ndarray/blas", + "concision-core/blas", + "ndarray/blas", ] complex = [ - "dep:num-complex", - "concision-core/complex", + "concision-core/complex", + "dep:num-complex", ] json = [ - "alloc", - "serde", - "serde_json", - "concision-core/json", - "reqwest?/json", + "alloc", + "concision-core/json", + "reqwest?/json", + "serde", + "serde_json", ] rayon = [ - "concision-core/rayon", - "dep:rayon", + "concision-core/rayon", + "dep:rayon", ] rand = [ - "concision-core/rand", - "rng", + "concision-core/rand", + "rng", ] rng = [ - "concision-core/rng", + "concision-core/rng", ] reqwest = ["dep:reqwest"] serde = [ - "dep:serde", - "concision-core/serde", - "ndarray/serde", - "num/serde", - "num-complex?/serde", + "concision-core/serde", + "dep:serde", + "ndarray/serde", + "num-complex?/serde", + "num/serde", ] serde_json = ["dep:serde_json"] tracing = [ - "concision-core/tracing", - "dep:tracing", + "concision-core/tracing", + "dep:tracing", ] - -# ********* [Metadata] ********* -[package.metadata.docs.rs] -all-features = false -doc-scrape-examples = true 
-features = ["full"] -rustc-args = ["--cfg", "docsrs"] -version = "v{{version}}" - -[package.metadata.release] -no-dev-version = true -tag-name = "{{version}}" diff --git a/derive/Cargo.toml b/derive/Cargo.toml index 8c4354e7..4b5e1a7f 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -18,11 +18,10 @@ version.workspace = true all-features = false features = ["default"] rustc-args = ["--cfg", "docsrs"] -version = "v{{version}}" [package.metadata.release] no-dev-version = true -tag-name = "{{version}}" +tag-name = "v{{version}}" [lib] bench = false @@ -36,10 +35,9 @@ proc-macro2 = "1" quote = "1" syn = { features = ["full"], version = "2" } - [features] default = [] nightly = [ - "proc-macro2/nightly" + "proc-macro2/nightly", ] diff --git a/derive/src/attrs.rs b/derive/src/attrs.rs new file mode 100644 index 00000000..4d45ac33 --- /dev/null +++ b/derive/src/attrs.rs @@ -0,0 +1,16 @@ +/* + Appellation: attrs + Created At: 2025.12.07:11:51:04 + Contrib: @FL03 +*/ +pub use self::{config::ConfigAttr, params::ParamsAttr}; + +mod config; +mod params; + +/// custom attributes for the concision derive macro +#[derive(Clone, Debug, Default)] +pub struct CncAttr { + pub config: Option, + pub params: Option, +} diff --git a/derive/src/attrs/config.rs b/derive/src/attrs/config.rs new file mode 100644 index 00000000..c8eb7738 --- /dev/null +++ b/derive/src/attrs/config.rs @@ -0,0 +1,11 @@ +/* + Appellation: attrs + Contrib: FL03 +*/ +use syn::Ident; + +#[derive(Clone, Debug, Default)] +pub struct ConfigAttr { + pub name: Option, + pub format: Option, +} diff --git a/derive/src/attrs/mod.rs b/derive/src/attrs/mod.rs deleted file mode 100644 index 8b137891..00000000 --- a/derive/src/attrs/mod.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/derive/src/attrs/attr.rs b/derive/src/attrs/params.rs similarity index 74% rename from derive/src/attrs/attr.rs rename to derive/src/attrs/params.rs index 7b4e4330..d7888b7b 100644 --- a/derive/src/attrs/attr.rs +++ b/derive/src/attrs/params.rs @@ -4,10 +4,6 @@ */ use syn::Ident; -pub struct ScsysAttr { - pub params: Option, -} - #[derive(Clone, Debug, Default)] pub struct ParamsAttr { pub name: Option, diff --git a/derive/src/params.rs b/derive/src/impls.rs similarity index 50% rename from derive/src/params.rs rename to derive/src/impls.rs index 10f63fcf..c19c760e 100644 --- a/derive/src/params.rs +++ b/derive/src/impls.rs @@ -1,15 +1,30 @@ /* - Appellation: params - Contrib: FL03 + Appellation: impls + Created At: 2025.12.07:11:55:02 + Contrib: @FL03 */ -pub mod keys; +mod impl_config; +mod impl_keys; use proc_macro2::TokenStream; use quote::{format_ident, quote}; use syn::{Data, DataStruct, DeriveInput}; -pub fn impl_params(input: &DeriveInput) -> TokenStream { +pub fn impl_config(DeriveInput { ident, data, .. }: &DeriveInput) -> TokenStream { + // ensure the target object is a struct + let out = match &data { + Data::Struct(s) => impl_config::derive_config_from_struct(&s, &ident), + _ => panic!("Only structs are supported"), + }; + + // Combine the generated code + quote! { + #out + } +} + +pub fn impl_keys(input: &DeriveInput) -> TokenStream { // Get the name of the struct let struct_name = &input.ident; let store_name = format_ident!("{}Key", struct_name); @@ -21,7 +36,7 @@ pub fn impl_params(input: &DeriveInput) -> TokenStream { Data::Struct(s) => { let DataStruct { fields, .. 
} = s; - keys::generate_keys(fields, &store_name) + impl_keys::generate_keys(fields, &store_name) } _ => panic!("Only structs are supported"), }; diff --git a/derive/src/impls/impl_config.rs b/derive/src/impls/impl_config.rs new file mode 100644 index 00000000..6c350d10 --- /dev/null +++ b/derive/src/impls/impl_config.rs @@ -0,0 +1,37 @@ +/* + Appellation: impl_config + Created At: 2025.12.07:11:54:21 + Contrib: @FL03 +*/ +use proc_macro2::TokenStream; +use quote::quote; +use syn::{DataStruct, Fields, FieldsNamed, FieldsUnnamed, Ident}; + +pub fn derive_config_from_struct( + DataStruct { fields, .. }: &DataStruct, + name: &Ident, +) -> TokenStream { + match fields { + Fields::Named(inner) => handle_named(inner, name), + Fields::Unnamed(inner) => handle_unnamed(inner, name), + _ => panic!("Unit fields aren't currently supported."), + } +} + +fn handle_named(_fields: &FieldsNamed, name: &Ident) -> TokenStream { + // let FieldsNamed { named, .. } = fields; + + quote! { + impl #name { + + } + } +} + +fn handle_unnamed(_fields: &FieldsUnnamed, name: &Ident) -> TokenStream { + quote! { + impl #name { + + } + } +} diff --git a/derive/src/params/keys.rs b/derive/src/impls/impl_keys.rs similarity index 100% rename from derive/src/params/keys.rs rename to derive/src/impls/impl_keys.rs diff --git a/derive/src/lib.rs b/derive/src/lib.rs index d1a9ebec..99ffbb9e 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -7,7 +7,6 @@ //! ## Overview //! //! -#![crate_name = "concision_derive"] #![crate_type = "proc-macro"] extern crate proc_macro; @@ -15,20 +14,35 @@ extern crate quote; extern crate syn; pub(crate) mod ast; +#[allow(dead_code)] pub(crate) mod attrs; -pub(crate) mod params; +pub(crate) mod impls; pub(crate) mod utils; use proc_macro::TokenStream; use syn::{DeriveInput, parse_macro_input}; +/// The [`Configuration`] derive macro generates configuration-related code for a given struct, +/// streamlining the process of creating compatible configuration spaces within the concision +/// framework. +#[proc_macro_derive(Configuration, attributes(config))] +pub fn configuration(input: TokenStream) -> TokenStream { + // Parse the input tokens into a syntax tree + let input = parse_macro_input!(input as DeriveInput); + + let res = impls::impl_config(&input); + + // Return the generated code as a TokenStream + res.into() +} + /// This macro generates a parameter struct and an enum of parameter keys. 
-#[proc_macro_derive(Keyed, attributes(param))] +#[proc_macro_derive(Keyed, attributes(keys))] pub fn keyed(input: TokenStream) -> TokenStream { // Parse the input tokens into a syntax tree let input = parse_macro_input!(input as DeriveInput); - let res = params::impl_params(&input); + let res = impls::impl_keys(&input); // Return the generated code as a TokenStream res.into() diff --git a/ext/Cargo.toml b/ext/Cargo.toml index d0cbc475..ac8ea512 100644 --- a/ext/Cargo.toml +++ b/ext/Cargo.toml @@ -1,50 +1,53 @@ [package] -build = "build.rs" +build = "build.rs" description = "this crate implements additional models using the concision framework" -name = "concision-ext" - -authors.workspace = true -categories.workspace = true -edition.workspace = true -homepage.workspace = true -keywords.workspace = true -license.workspace = true -readme.workspace = true -repository.workspace = true +name = "concision-ext" + +authors.workspace = true +categories.workspace = true +edition.workspace = true +homepage.workspace = true +keywords.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true rust-version.workspace = true -version.workspace = true +version.workspace = true [package.metadata.docs.rs] all-features = false -features = ["full"] -rustc-args = ["--cfg", "docsrs"] -version = "v{{version}}" +doc-scrape-examples = true +features = ["full"] +rustc-args = ["--cfg", "docsrs"] [package.metadata.release] no-dev-version = true -tag-name = "{{version}}" +tag-name = "v{{version}}" [lib] -bench = false +bench = false crate-type = ["cdylib", "rlib"] -doc = true -doctest = true -test = true +doctest = false +test = true [[example]] name = "attention" required-features = ["attention", "rand", "std"] +[[example]] +name = "snn" +required-features = ["approx", "snn", "std"] + [[test]] name = "attention" required-features = ["approx", "attention", "rand", "std"] [[test]] -name = "snn" +name = "snn" required-features = ["approx", "snn", "std"] [dependencies] -concision = { features = ["neural"], workspace = true } +concision = { workspace = true } # custom variants = { workspace = true } # error handling @@ -54,33 +57,34 @@ rayon = { optional = true, workspace = true } # data-structures ndarray = { workspace = true } # mathematics -approx = { optional = true, workspace = true } -num = { workspace = true } +approx = { optional = true, workspace = true } +num = { workspace = true } num-complex = { optional = true, workspace = true } -num-traits = { workspace = true } -rustfft = { optional = true, workspace = true } +num-traits = { workspace = true } +rustfft = { optional = true, workspace = true } # serialization -serde = { optional = true, workspace = true } +serde = { optional = true, workspace = true } serde_derive = { optional = true, workspace = true } -serde_json = { optional = true, workspace = true } +serde_json = { optional = true, workspace = true } # logging tracing = { optional = true, workspace = true } [dev-dependencies] -lazy_static = { workspace = true } +lazy_static = { workspace = true } tracing-subscriber = { features = ["std"], workspace = true } [features] default = ["attention", "std"] full = [ - "complex", "default", + "approx", + "complex", "json", + "models", "rand", "serde", "tracing", - "models", ] nightly = ["concision/nightly"] @@ -92,6 +96,8 @@ signal = ["complex", "rustfft"] # ********* [FF:Models] ********* models = [ + "s4", + "snn", "transformer", ] @@ -140,7 +146,11 @@ alloc = [ "variants/alloc", ] -approx = ["concision/approx", 
"dep:approx", "ndarray/approx"] +approx = [ + "dep:approx", + "concision/approx", + "ndarray/approx", +] blas = ["concision/blas", "ndarray/blas"] @@ -155,9 +165,9 @@ rng = ["concision/rng"] rustfft = ["dep:rustfft"] serde = [ - "concision/serde", "dep:serde", "dep:serde_derive", + "concision/serde", "ndarray/serde", "num-complex?/serde", "num/serde", @@ -165,4 +175,7 @@ serde = [ serde_json = ["dep:serde_json"] -tracing = ["concision/tracing", "dep:tracing"] +tracing = [ + "dep:tracing", + "concision/tracing", +] diff --git a/ext/examples/attention.rs b/ext/examples/attention.rs index 3d3fca29..76bd16f8 100644 --- a/ext/examples/attention.rs +++ b/ext/examples/attention.rs @@ -8,8 +8,7 @@ use concision_ext::attention::{Qkv, ScaledDotProductAttention}; fn main() -> anyhow::Result<()> { tracing_subscriber::fmt() .with_max_level(tracing::Level::TRACE) - .with_target(false) - .without_time() + .with_timer(tracing_subscriber::fmt::time::Uptime::default()) .init(); let (m, n) = (7, 10); let qkv = Qkv::::ones((m, n)); diff --git a/ext/examples/snn.rs b/ext/examples/snn.rs new file mode 100644 index 00000000..19f8bd3b --- /dev/null +++ b/ext/examples/snn.rs @@ -0,0 +1,86 @@ +/* + Appellation: snn + Created At: 2025.12.08:15:27:07 + Contrib: @FL03 +*/ +//! Minimal demonstration of neuron usage. Simulates a neuron for `t_sim` ms with dt, +//! injects a constant external current `i_ext`, and injects discrete synaptic events at specified times. +use concision_ext::snn::{LIFNeuron, SynapticEvent}; + +fn main() -> anyhow::Result<()> { + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .with_timer(tracing_subscriber::fmt::time::Uptime::default()) + .init(); + // Simulation parameters + let dt = 0.1; // ms + let t_sim = 5000.0; // ms + let steps = (t_sim / dt) as usize; + + // Create neuron with defaults + let mut neuron = LIFNeuron::default(); + + // Example external current (constant) + // Increase drive so steady-state v can reach threshold (v_rest + R*i_ext > v_thresh). + // With default params: v_rest = -65, v_thresh = -50 → need i_ext >= 15. 
+ let i_ext = 16.0; // stronger constant drive to induce spiking + + // Example presynaptic spike times (ms) and weights + let presyn_spikes: Vec<(f64, f64)> = + vec![(50.0, 2.0), (100.0, 1.5), (150.0, 2.2), (300.0, 3.0)]; + + // Convert into an index-able event list + let mut events: Vec> = vec![Vec::new(); steps + 1]; + for (t_spike, weight) in presyn_spikes { + let idx = (t_spike / dt).round() as isize; + if idx >= 0 && (idx as usize) < events.len() { + events[idx as usize].push(SynapticEvent { weight }); + } + } + + // Simulation loop + let mut spike_times = Vec::::new(); + for step in 0..steps { + let t = step as f64 * dt; + + // deliver presynaptic events scheduled for this time step + for ev in &events[step] { + neuron.apply_spike(ev.weight); + } + + // step the neuron + let res = neuron.step(dt, i_ext); + + if res.is_spiked() { + spike_times.push(t); + // print the pre-spike membrane potential from the step result + println!( + "Spike at {:.3} ms (pre-spike v = {:.3})", + t, + res.membrane_potential() + ); + } + + // optionally, record v, w, s for analysis (omitted here for brevity) + let _v = neuron.membrane_potential(); + let _w = neuron.adaptation(); + let _s = neuron.synaptic_state(); + + // small example of printing membrane potential every 50 ms + if step % ((50.0 / dt) as usize) == 0 { + // show the step result potential (pre-reset when a spike occurred) + tracing::info!( + "t={:.1} ms, v={:.3} mV, w={:.3}, s={:.3}", + t, + res.membrane_potential(), + _w, + _s + ); + } + } + + tracing::info!("Total spikes: {}", spike_times.len()); + tracing::info!("Spike times: {:?}", spike_times); + + Ok(()) +} diff --git a/ext/src/attention/fft.rs b/ext/src/attention/fft.rs index 92299c4d..01d68f9c 100644 --- a/ext/src/attention/fft.rs +++ b/ext/src/attention/fft.rs @@ -87,7 +87,7 @@ impl FftAttention { level = "trace", ) )] - pub fn forward(&self, input: &X) -> cnc::Result + pub fn forward(&self, input: &X) -> Y where Self: Forward, { @@ -111,12 +111,16 @@ where { type Output = Array1; - fn forward(&self, input: &ArrayBase) -> cnc::Result { + fn forward(&self, input: &ArrayBase) -> Self::Output { let seq_len = input.dim(); let n = A::from_usize(seq_len).unwrap(); if seq_len == 0 { - return Err(cnc::params::ParamsError::InvalidInputShape.into()); + return Err(cnc::params::ParamsError::MismatchedShapes { + expected: &[1], + found: 0, + } + .into()); } // Create FFT planner @@ -225,7 +229,7 @@ where let (seq_len, feature_dim) = input.dim(); if seq_len == 0 { - return Err(cnc::params::ParamsError::InvalidInputShape.into()); + return Err(anyhow::anyhow!("Input sequence length cannot be zero")); } // Create FFT planner diff --git a/ext/src/attention/qkv.rs b/ext/src/attention/qkv.rs index 19b48829..ef1c3cef 100644 --- a/ext/src/attention/qkv.rs +++ b/ext/src/attention/qkv.rs @@ -25,7 +25,11 @@ where D: Dimension, S: RawData, { - pub fn new(query: ArrayBase, key: ArrayBase, value: ArrayBase) -> Self { + pub fn new( + query: ArrayBase, + key: ArrayBase, + value: ArrayBase, + ) -> Self { Self { query, key, value } } pub fn from_elem>(shape: Sh, elem: A) -> Self @@ -128,11 +132,10 @@ where { type Output = Y; - fn forward(&self, input: &X) -> Option { + fn forward(&self, input: &X) -> Y { let query = input.dot(&self.query); let key = input.dot(&self.key); let value = input.dot(&self.value); - let output = query + key + value; - Some(output) + query + key + value } } diff --git a/ext/src/kan/model.rs b/ext/src/kan/model.rs index ec21c8b1..5a6a11d1 100644 --- a/ext/src/kan/model.rs +++ 
+++ b/ext/src/kan/model.rs
@@ -2,11 +2,11 @@
     appellation: model
     authors: @FL03
 */
+use cnc::config::StandardModelConfig;
+use cnc::prelude::{DeepModelParams, Model, ModelFeatures};
-use cnc::nn::{DeepModelParams, Model, ModelFeatures, StandardModelConfig};
 #[cfg(feature = "rand")]
-use cnc::rand_distr;
-
+use cnc::init::rand_distr::{Distribution, StandardNormal};
 use num_traits::{Float, FromPrimitive};

 #[derive(Clone, Debug)]
@@ -108,8 +108,8 @@ impl<T> Model<T> for KanModel<T> {
         &mut self.config
     }

-    fn layout(&self) -> ModelFeatures {
-        self.features
+    fn layout(&self) -> &ModelFeatures {
+        &self.features
     }

     fn params(&self) -> &DeepModelParams<T> {
diff --git a/ext/src/lib.rs b/ext/src/lib.rs
index eb493376..5f9c97b3 100644
--- a/ext/src/lib.rs
+++ b/ext/src/lib.rs
@@ -16,6 +16,7 @@
 #[cfg(feature = "alloc")]
 extern crate alloc;

+extern crate concision as cnc;

 #[cfg(feature = "attention")]
 pub mod attention;
diff --git a/ext/src/s4/model.rs b/ext/src/s4/model.rs
index 1144521e..142df60c 100644
--- a/ext/src/s4/model.rs
+++ b/ext/src/s4/model.rs
@@ -2,11 +2,11 @@
     appellation: model
     authors: @FL03
 */
+use cnc::config::StandardModelConfig;
+use cnc::prelude::{DeepModelParams, Model, ModelFeatures};
-use cnc::nn::{DeepModelParams, Model, ModelFeatures, StandardModelConfig};
 #[cfg(feature = "rand")]
-use cnc::rand_distr;
-
+use cnc::init::rand_distr::{Distribution, StandardNormal};
 use num_traits::{Float, FromPrimitive};

 #[derive(Clone, Debug)]
@@ -90,7 +90,7 @@ where
     pub fn init(self) -> Self
     where
         T: 'static + Float + FromPrimitive,
-        rand_distr::StandardNormal: rand_distr::Distribution<T>,
+        StandardNormal: Distribution<T>,
     {
         let params = DeepModelParams::glorot_normal(self.features());
         S4Model { params, ..self }
@@ -109,8 +109,8 @@ impl<T> Model<T> for S4Model<T> {
         &mut self.config
     }

-    fn layout(&self) -> ModelFeatures {
-        self.features
+    fn layout(&self) -> &ModelFeatures {
+        &self.features
     }

     fn params(&self) -> &DeepModelParams<T> {
diff --git a/ext/src/snn/mod.rs b/ext/src/snn/mod.rs
index 84b78032..017c38f4 100644
--- a/ext/src/snn/mod.rs
+++ b/ext/src/snn/mod.rs
@@ -1,18 +1,54 @@
-/*
-    appellation: snn
-    authors: @FL03
-*/
 //! Spiking neural networks (SNNs) for the [`concision`](https://crates.io/crates/concision) machine learning framework.
 //!
 //! ## References
 //!
 //! - [Deep Learning in Spiking Neural Networks](https://arxiv.org/abs/1804.08150)
 //!
+//! ## Background
+//!
+//! Spiking Neural Networks (SNNs) are a class of artificial neural networks that more closely
+//! mimic the behavior of biological neurons compared to traditional artificial neural
+//! networks. In SNNs, neurons communicate by sending discrete spikes (or action potentials)
+//! rather than continuous values. This allows SNNs to capture temporal dynamics and
+//! event-driven processing, making them suitable for tasks that involve time-series data
+//! or require low-power computation.
+//!
+//! ### Model (forward-Euler integration; units are arbitrary but consistent):
+//!
+//! ```math
+//! \tau_m * \frac{dv}{dt} = -(v - v_{rest}) + R*(I_{ext} + I_{syn}) - \omega
+//! ```
+//!
+//! ```math
+//! \tau_w * \frac{d\omega}{dt} = -\omega
+//! ```
+//!
+//! ```math
+//! \tau_s * \frac{ds}{dt} = -s
+//! ```
+//!
+//! where:
+//! - $`v`$: membrane potential
+//! - $`\omega`$: adaptation variable
+//! - $`s`$: synaptic variable representing total synaptic current
+//!
+//! If we allow the spike to be represented as $`\delta`$, then:
+//!
+//! ```math
+//! v\geq{v_{thresh}}\rightarrow{\delta},v\leftarrow{v_{reset}},\omega\mathrel{+}=b
+//! ```
+//!
+//! where $`b`$ is the adaptation increment added on spike. The synaptic current is given by:
+//!
+//! ```math
+//! I_{syn} = s
+//! ```

 #[doc(inline)]
-pub use self::{model::*, neuron::*, types::*};
+pub use self::{model::*, neurons::*, types::*, utils::*};

 mod model;
-mod neuron;
+mod neurons;
+mod utils;

 pub mod types {
     //! Types for spiking neural networks
@@ -25,6 +61,6 @@ pub(crate) mod prelude {
     pub use super::model::*;
-    pub use super::neuron::*;
+    pub use super::neurons::*;
     pub use super::types::*;
 }
diff --git a/ext/src/snn/model.rs b/ext/src/snn/model.rs
index fc0d96dc..0f33f231 100644
--- a/ext/src/snn/model.rs
+++ b/ext/src/snn/model.rs
@@ -2,11 +2,11 @@
     appellation: model
     authors: @FL03
 */
+use cnc::config::StandardModelConfig;
+use cnc::prelude::{DeepModelParams, Model, ModelFeatures};
-use cnc::nn::{DeepModelParams, Model, ModelFeatures, StandardModelConfig};
 #[cfg(feature = "rand")]
-use cnc::rand_distr;
-
+use cnc::init::rand_distr::{Distribution, StandardNormal};
 use num_traits::{Float, FromPrimitive};

 #[derive(Clone, Debug)]
@@ -53,42 +53,38 @@ impl<T> SpikingNeuralNetwork<T> {
         &mut self.params
     }
     /// set the current configuration and return a mutable reference to the model
-    pub fn set_config(&mut self, config: StandardModelConfig<T>) -> &mut Self {
+    pub fn set_config(&mut self, config: StandardModelConfig<T>) {
         self.config = config;
-        self
     }
     /// set the current features and return a mutable reference to the model
-    pub fn set_features(&mut self, features: ModelFeatures) -> &mut Self {
+    pub const fn set_features(&mut self, features: ModelFeatures) {
         self.features = features;
-        self
     }
     /// set the current parameters and return a mutable reference to the model
-    pub fn set_params(&mut self, params: DeepModelParams<T>) -> &mut Self {
+    pub fn set_params(&mut self, params: DeepModelParams<T>) {
         self.params = params;
-        self
     }
+    #[inline]
     /// consumes the current instance to create another with the given configuration
     pub fn with_config(self, config: StandardModelConfig<T>) -> Self {
         Self { config, ..self }
     }
+    #[inline]
     /// consumes the current instance to create another with the given features
     pub fn with_features(self, features: ModelFeatures) -> Self {
         Self { features, ..self }
     }
+    #[inline]
     /// consumes the current instance to create another with the given parameters
     pub fn with_params(self, params: DeepModelParams<T>) -> Self {
         Self { params, ..self }
     }
-}
-
-impl<T> SpikingNeuralNetwork<T>
-where
-    T: 'static + Float + FromPrimitive,
-{
     #[cfg(feature = "rand")]
     pub fn init(self) -> Self
     where
-        rand_distr::StandardNormal: rand_distr::Distribution<T>,
+        T: 'static + Float + FromPrimitive,
+        StandardNormal: Distribution<T>,
     {
         let params = DeepModelParams::glorot_normal(self.features());
         SpikingNeuralNetwork { params, ..self }
@@ -108,8 +104,8 @@ impl<T> Model<T> for SpikingNeuralNetwork<T> {
         &mut self.config
     }

-    fn layout(&self) -> ModelFeatures {
-        self.features
+    fn layout(&self) -> &ModelFeatures {
+        &self.features
     }

     fn params(&self) -> &DeepModelParams<T> {
diff --git a/ext/src/snn/neuron.rs b/ext/src/snn/neuron.rs
deleted file mode 100644
index 3fdc5395..00000000
--- a/ext/src/snn/neuron.rs
+++ /dev/null
@@ -1,311 +0,0 @@
-/*
-    Appellation: neuron
-    Created At: 2025.11.25:09:33:30
-    Contrib: @FL03
-*/
-//! Single spiking neuron (LIF + adaptation + exponential synapse) example in pure Rust.
-//!
-//! ## Background
-//!
-//! Model (forward-Euler integration; units are arbitrary but consistent):
-//!
-//! ```math
-//! \tau_m * \frac{dv}{dt} = -(v - v_{rest}) + R*(I_{ext} + I_{syn}) - \omega
-//! ```
-//!
-//! ```math
-//! \tau_w * \frac{d\omega}{dt} = -\omega
-//! ```
-//!
-//! ```math
-//! \tau_s * \frac{ds}{dt} = -s
-//! ```
-//!
-//! where:
-//! - $`v`$: membrane potential
-//! - $\omega$: adaptation variable
-//! - $`s`$: synaptic variable representing total synaptic current
-//!
-//! If we allow the spike to be represented as $\delta$, then:
-//!
-//! ```math
-//! v\geq{v_{thresh}}\rightarrow{\delta},v\leftarrow{v_{reset}},\omega\mathrel{+}=b
-//! ```
-//!
-//! and where `b` is the adaptation increment added on spike.
-//! The synaptic current is given by: $I_{syn} = s$
-//!
-//! The implementation is conservative with allocations and idiomatic Rust.
-use super::types::{StepResult, SynapticEvent};
-
-/// Leaky Integrate-and-Fire neuron with an adaptation term and exponential synaptic current.
-///
-/// All fields are public for convenience in research workflows; in production you may want to
-/// expose read-only getters and safe setters only.
-#[derive(Clone)]
-pub struct SpikingNeuron {
-    // ---- Parameters ----
-    /// Membrane time constant `tau_m` (ms)
-    pub tau_m: f64,
-    /// Membrane resistance `R` (MΩ or arbitrary)
-    pub resistance: f64,
-    /// Resting potential `v_rest` (mV)
-    pub v_rest: f64,
-    /// Threshold potential `v_thresh` (mV)
-    pub v_thresh: f64,
-    /// Reset potential after spike `v_reset` (mV)
-    pub v_reset: f64,
-
-    /// Adaptation time constant `tau_w` (ms)
-    pub tau_w: f64,
-    /// Adaptation increment added on spike `b` (same units as w/current)
-    pub b: f64,
-
-    /// Synaptic time constant `tau_s` (ms)
-    pub tau_s: f64,
-
-    // ---- State variables ----
-    /// Membrane potential `v`
-    pub v: f64,
-    /// Adaptation variable `w`
-    pub w: f64,
-    /// Synaptic variable `s` representing total synaptic current
-    pub s: f64,
-
-    // ---- Optional numerical safeguards ----
-    /// Minimum allowed dt for integration (ms)
-    pub min_dt: f64,
-}
-
-impl SpikingNeuron {
-    #[allow(clippy::should_implement_trait)]
-    /// Create a new neuron with common default parameters (units: ms and mV-like).
-    ///
-    /// Many fields are set to common neuroscience-like defaults but these are research parameters
-    /// and should be tuned for your experiments.
-    pub fn default() -> Self {
-        let tau_m = 20.0; // ms
-        let resistance = 1.0; // arbitrary
-        let v_rest = -65.0; // mV
-        let v_thresh = -50.0; // mV
-        let v_reset = -65.0; // mV
-        let tau_w = 200.0; // ms (slow adaptation)
-        let b = 0.5; // adaptation increment
-        let tau_s = 5.0; // ms (fast synapse)
-        Self {
-            tau_m,
-            resistance,
-            v_rest,
-            v_thresh,
-            v_reset,
-            tau_w,
-            b,
-            tau_s,
-            v: v_rest,
-            w: 0.0,
-            s: 0.0,
-            min_dt: 1e-6,
-        }
-    }
-
-    /// Create a neuron with explicit parameters and initial state.
-    pub const fn new(
-        tau_m: f64,
-        resistance: f64,
-        v_rest: f64,
-        v_thresh: f64,
-        v_reset: f64,
-        tau_w: f64,
-        b: f64,
-        tau_s: f64,
-        initial_v: Option<f64>,
-    ) -> Self {
-        let v0 = if let Some(v_init) = initial_v {
-            v_init
-        } else {
-            v_rest
-        };
-        Self {
-            tau_m,
-            resistance,
-            v_rest,
-            v_thresh,
-            v_reset,
-            tau_w,
-            b,
-            tau_s,
-            v: v0,
-            w: 0.0,
-            s: 0.0,
-            min_dt: 1e-6,
-        }
-    }
-
-    /// Reset state variables (keeps parameters).
-    pub fn reset(&mut self) {
-        self.v = self.v_rest;
-        self.w = 0.0;
-        self.s = 0.0;
-    }
-
-    /// Apply a presynaptic spike event to the neuron; this increments the synaptic variable `s`
-    /// by `weight` instantaneously (models delta spike arrival).
-    pub fn receive_spike(&mut self, weight: f64) {
-        self.s += weight;
-    }
-
-    /// Integrate the neuron state forward by `dt` milliseconds using forward Euler.
-    ///
-    /// `i_ext` is an externally injected current (same units as `s`).
-    /// `dt` must be > 0.
-    pub fn step(&mut self, dt: f64, i_ext: f64) -> StepResult {
-        let dt = if dt <= 0.0 {
-            panic!("dt must be > 0")
-        } else {
-            dt.max(self.min_dt)
-        };
-
-        // synaptic current is represented by `s`
-        // ds/dt = -s / tau_s
-        let ds = -self.s / self.tau_s;
-        let s_next = self.s + dt * ds;
-
-        // total synaptic current for this step (use current s, or average between s and s_next)
-        // we use s for explicit Euler consistency.
-        let i_syn = self.s;
-
-        // membrane dv/dt = (-(v - v_rest) + R*(i_ext + i_syn) - w) / tau_m
-        let dv =
-            (-(self.v - self.v_rest) + self.resistance * (i_ext + i_syn) - self.w) / self.tau_m;
-        let v_next = self.v + dt * dv;
-
-        // adaptation dw/dt = -w / tau_w
-        let dw = -self.w / self.tau_w;
-        let w_next = self.w + dt * dw;
-
-        // Commit state tentatively
-        self.v = v_next;
-        self.w = w_next;
-        self.s = s_next;
-
-        // Check for spike (simple threshold crossing)
-        if self.v >= self.v_thresh {
-            // spike: apply reset and adaptation increment
-            self.v = self.v_reset;
-            self.w += self.b;
-            StepResult {
-                spiked: true,
-                v: self.v,
-            }
-        } else {
-            StepResult {
-                spiked: false,
-                v: self.v,
-            }
-        }
-    }
-
-    /// Get current membrane potential
-    pub fn membrane_potential(&self) -> f64 {
-        self.v
-    }
-
-    /// Get current synaptic variable
-    pub fn synaptic_state(&self) -> f64 {
-        self.s
-    }
-
-    /// Get adaptation variable
-    pub fn adaptation(&self) -> f64 {
-        self.w
-    }
-}
-
-impl Default for SpikingNeuron {
-    fn default() -> Self {
-        let tau_m = 20.0; // ms
-        let resistance = 1.0; // arbitrary
-        let v_rest = -65.0; // mV
-        let v_thresh = -50.0; // mV
-        let v_reset = -65.0; // mV
-        let tau_w = 200.0; // ms (slow adaptation)
-        let b = 0.5; // adaptation increment
-        let tau_s = 5.0; // ms (fast synapse)
-        Self {
-            tau_m,
-            resistance,
-            v_rest,
-            v_thresh,
-            v_reset,
-            tau_w,
-            b,
-            tau_s,
-            v: v_rest,
-            w: 0.0,
-            s: 0.0,
-            min_dt: 1e-6,
-        }
-    }
-}
-
-#[allow(dead_code)]
-/// Minimal demonstration of neuron usage. Simulates a neuron for `t_sim` ms with dt,
-/// injects a constant external current `i_ext`, and injects discrete synaptic events at specified times.
-fn example() {
-    // Simulation parameters
-    let dt = 0.1; // ms
-    let t_sim = 500.0; // ms
-    let steps = (t_sim / dt) as usize;
-
-    // Create neuron with defaults
-    let mut neuron = SpikingNeuron::default();
-
-    // Example external current (constant)
-    let i_ext = 1.8; // tune to see spiking (units consistent with resistance & s)
-
-    // Example presynaptic spike times (ms) and weights
-    let presyn_spikes: Vec<(f64, f64)> =
-        vec![(50.0, 2.0), (100.0, 1.5), (150.0, 2.2), (300.0, 3.0)];
-
-    // Convert into an index-able event list
-    let mut events: Vec<Vec<SynapticEvent>> = vec![Vec::new(); steps + 1];
-    for (t_spike, weight) in presyn_spikes {
-        let idx = (t_spike / dt).round() as isize;
-        if idx >= 0 && (idx as usize) < events.len() {
-            events[idx as usize].push(SynapticEvent { weight });
-        }
-    }
-
-    // Simulation loop
-    let mut spike_times: Vec<f64> = Vec::new();
-    for step in 0..steps {
-        let t = step as f64 * dt;
-
-        // deliver presynaptic events scheduled for this time step
-        for ev in &events[step] {
-            neuron.receive_spike(ev.weight);
-        }
-
-        // step the neuron
-        let res = neuron.step(dt, i_ext);
-
-        if res.spiked {
-            spike_times.push(t);
-            // For debugging: print spike time
-            println!("Spike at {:.3} ms (v reset = {:.3})", t, neuron.v);
-        }
-
-        // optionally, record v, w, s for analysis (omitted here for brevity)
-        let _v = neuron.membrane_potential();
-        let _w = neuron.adaptation();
-        let _s = neuron.synaptic_state();
-
-        // small example of printing membrane potential every 50 ms
-        if step % ((50.0 / dt) as usize) == 0 {
-            println!("t={:.1} ms, v={:.3} mV, w={:.3}, s={:.3}", t, _v, _w, _s);
-        }
-    }
-
-    println!("Total spikes: {}", spike_times.len());
-    println!("Spike times: {:?}", spike_times);
-}
diff --git a/ext/src/snn/neurons/lif.rs b/ext/src/snn/neurons/lif.rs
new file mode 100644
index 00000000..eacbe16c
--- /dev/null
+++ b/ext/src/snn/neurons/lif.rs
@@ -0,0 +1,214 @@
+/*
+    Appellation: neuron
+    Created At: 2025.11.25:09:33:30
+    Contrib: @FL03
+*/
+
+use crate::snn::StepResult;
+use num_traits::{Float, FromPrimitive, NumAssign, Zero};
+
+/// Leaky Integrate-and-Fire (LIF) neuron with an adaptation term and exponential synaptic
+/// current.
+///
+/// The neuron dynamics are governed by the following equations:
+///
+/// ```math
+/// \frac{dv}{dt} = \frac{-(v - v_{rest}) + R \cdot (i_{ext} + s) - w}{\tau_{m}}
+/// ```
+#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
+#[cfg_attr(
+    feature = "serde",
+    derive(serde::Deserialize, serde::Serialize),
+    serde(rename_all = "snake_case")
+)]
+pub struct LIFNeuron<T> {
+    // ---- Parameters ----
+    /// Membrane time constant $`\tau_{m}`$ (ms)
+    pub tau_m: T,
+    /// Membrane resistance `R` (MΩ or arbitrary)
+    pub resistance: T,
+    /// Resting potential $`v_{rest}`$ (mV)
+    pub v_rest: T,
+    /// Threshold potential $`v_{thresh}`$ (mV)
+    pub v_thresh: T,
+    /// Reset potential after spike $`v_{reset}`$ (mV)
+    pub v_reset: T,
+
+    /// Adaptation time constant $`\tau_{w}`$ (ms)
+    pub tau_w: T,
+    /// Adaptation increment added on spike `b` (same units as w/current)
+    pub b: T,
+
+    /// Synaptic time constant $`\tau_{s}`$ (ms)
+    pub tau_s: T,
+
+    // ---- State variables ----
+    /// Membrane potential `v`
+    pub v: T,
+    /// Adaptation variable `w`
+    pub w: T,
+    /// Synaptic variable `s` representing total synaptic current
+    pub s: T,
+
+    /// Minimum allowed dt for integration (ms)
+    pub min_dt: T,
+}
+
+impl<T> LIFNeuron<T> {
+    /// Create a neuron with explicit parameters and initial state.
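+    ///
+    /// A minimal usage sketch (hypothetical values mirroring the defaults; units follow the
+    /// field docs, i.e. ms- and mV-like):
+    ///
+    /// ```ignore
+    /// // tau_m, R, v_rest, v_thresh, v_reset, tau_w, b, tau_s, initial_v
+    /// let neuron = LIFNeuron::new(20.0, 1.0, -65.0, -50.0, -65.0, 200.0, 0.5, 5.0, None);
+    /// ```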
+    pub fn new(
+        tau_m: T,
+        resistance: T,
+        v_rest: T,
+        v_thresh: T,
+        v_reset: T,
+        tau_w: T,
+        b: T,
+        tau_s: T,
+        initial_v: Option<T>,
+    ) -> Self
+    where
+        T: Float + FromPrimitive,
+    {
+        let v0 = if let Some(v_init) = initial_v {
+            v_init
+        } else {
+            v_rest
+        };
+        Self {
+            tau_m,
+            resistance,
+            v_rest,
+            v_thresh,
+            v_reset,
+            tau_w,
+            b,
+            tau_s,
+            v: v0,
+            w: T::zero(),
+            s: T::zero(),
+            min_dt: T::from_f32(1e-6).unwrap(),
+        }
+    }
+    /// returns a reference to the neuron's adaptation variable (`w`)
+    pub const fn adaptation(&self) -> &T {
+        &self.w
+    }
+    /// returns a reference to the membrane potential, `v`, of the neuron
+    pub const fn membrane_potential(&self) -> &T {
+        &self.v
+    }
+    /// returns a reference to the current value, or synaptic state, of the neuron (`s`)
+    pub const fn synaptic_state(&self) -> &T {
+        &self.s
+    }
+    /// returns a reference to the membrane time constant, `tau_m`, of the neuron
+    pub const fn tau_m(&self) -> &T {
+        &self.tau_m
+    }
+    /// returns a reference to the membrane resistance, `R`, of the neuron
+    pub const fn resistance(&self) -> &T {
+        &self.resistance
+    }
+    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "trace"))]
+    /// Apply a presynaptic spike event to the neuron; this increments the synaptic variable `s`
+    /// by `weight` instantaneously (models delta spike arrival).
+    pub fn apply_spike(&mut self, weight: T)
+    where
+        T: NumAssign + Zero,
+    {
+        self.s += weight;
+    }
+    /// reset state variables (keeps parameters).
+    pub fn reset_state(&mut self)
+    where
+        T: Clone + Default,
+    {
+        self.v = self.v_rest.clone();
+        self.w = T::default();
+        self.s = T::default();
+    }
+    #[cfg_attr(feature = "tracing", tracing::instrument(skip_all, level = "trace"))]
+    /// Integrate the neuron state forward by `dt` [ms] using forward Euler; the externally
+    /// applied current, `i_ext`, is added to the synaptic current `s` for the integration
+    /// step. Therefore it is important to maintain unit consistency between `i_ext` and `s`
+    /// and to ensure that the provided `dt` is sufficiently small to avoid missing spikes, yet
+    /// still greater than 0
+    ///
+    /// **Note**: This method checks for threshold crossing explicitly to avoid missing spikes
+    /// due to large `dt`. Additionally, if `dt` is less than `min_dt`, it is clamped to
+    /// `min_dt`.
+    pub fn step(&mut self, dt: T, i_ext: T) -> StepResult<T>
+    where
+        T: Float + FromPrimitive + NumAssign,
+    {
+        let dt = if dt.is_sign_negative() {
+            panic!("dt must be > 0")
+        } else {
+            dt.max(self.min_dt)
+        };
+
+        // remember previous membrane potential for crossing detection
+        let v_prev = self.v;
+
+        // synaptic current is represented by `s`
+        // ds/dt = -s / tau_s
+        let ds = -self.s / self.tau_s;
+        let s_next = self.s + dt * ds;
+
+        // total synaptic current for this step (use current s, or average between s and s_next)
+        // we use s for explicit Euler consistency.
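+        // NOTE (sketch of an alternative, not required here): a semi-implicit scheme could
+        // use the average 0.5 * (self.s + s_next) as the current instead, trading one extra
+        // add/mul per step for a smaller discretization error; the explicit value keeps all
+        // state variables advanced from the same time point.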
+        let i_syn = self.s;
+
+        // membrane dv/dt = (-(v - v_rest) + R*(i_ext + i_syn) - w) / tau_m
+        let dv =
+            (-(self.v - self.v_rest) + self.resistance * (i_ext + i_syn) - self.w) / self.tau_m;
+        let v_next = self.v + dt * dv;
+
+        // adaptation dw/dt = -w / tau_w
+        let dw = -self.w / self.tau_w;
+        let w_next = self.w + dt * dw;
+
+        // Commit state tentatively
+        self.v = v_next;
+        self.w = w_next;
+        self.s = s_next;
+
+        // Check for threshold crossing (explicit crossing test to avoid misses)
+        if v_prev < self.v_thresh && self.v >= self.v_thresh {
+            // spike: capture pre-reset potential if that is expected by StepResult consumers
+            let pre_spike_v = self.v;
+            // apply reset and adaptation increment
+            self.v = self.v_reset;
+            self.w += self.b;
+            StepResult {
+                spiked: true,
+                v: pre_spike_v,
+            }
+        } else {
+            StepResult {
+                spiked: false,
+                v: self.v,
+            }
+        }
+    }
+}
+
+impl<T> Default for LIFNeuron<T>
+where
+    T: Float + FromPrimitive,
+{
+    fn default() -> Self {
+        let tau_m = T::from_usize(20).unwrap(); // ms
+        let resistance = T::one(); // arbitrary
+        let v_rest = T::from_usize(65).unwrap().neg(); // mV
+        let v_thresh = T::from_usize(50).unwrap().neg(); // mV
+        let v_reset = T::from_usize(65).unwrap().neg(); // mV
+        let tau_w = T::from_usize(200).unwrap(); // ms (slow adaptation)
+        let b = T::from_f32(0.5).unwrap(); // adaptation increment
+        let tau_s = T::from_usize(5).unwrap(); // ms (fast synapse)
+        Self::new(
+            tau_m, resistance, v_rest, v_thresh, v_reset, tau_w, b, tau_s, None,
+        )
+    }
+}
diff --git a/ext/src/snn/neurons/mod.rs b/ext/src/snn/neurons/mod.rs
new file mode 100644
index 00000000..2ceb1744
--- /dev/null
+++ b/ext/src/snn/neurons/mod.rs
@@ -0,0 +1,9 @@
+/*
+    Appellation: neurons
+    Created At: 2025.12.08:17:26:07
+    Contrib: @FL03
+*/
+#[doc(inline)]
+pub use self::lif::LIFNeuron;
+
+pub mod lif;
diff --git a/ext/src/snn/types/event.rs b/ext/src/snn/types/event.rs
index b731ab78..7012f48b 100644
--- a/ext/src/snn/types/event.rs
+++ b/ext/src/snn/types/event.rs
@@ -4,8 +4,7 @@
     Contrib: @FL03
 */

-/// A simple synaptic event: weight added to synaptic variable `s` when it arrives.
-
+/// A synaptic event that modifies the synaptic variable `s` by an instantaneous weight.
 #[derive(Clone, Copy, Debug, Default, Eq, Hash, Ord, PartialEq, PartialOrd)]
 #[cfg_attr(feature = "serde", derive(serde::Deserialize, serde::Serialize))]
 pub struct SynapticEvent<T> {
@@ -14,8 +13,73 @@ pub struct SynapticEvent<T> {
 }

 impl<T> SynapticEvent<T> {
-    /// Create a new SynapticEvent
+    /// returns a new instance of the `SynapticEvent` using the given weight
     pub const fn new(weight: T) -> Self {
         Self { weight }
     }
+    /// returns a reference to the weight
+    pub const fn weight(&self) -> &T {
+        &self.weight
+    }
+    /// returns a mutable reference to the weight of the synaptic event
+    pub const fn weight_mut(&mut self) -> &mut T {
+        &mut self.weight
+    }
+    /// [`replace`](core::mem::replace) the weight with a new value, returning the old value.
+    pub const fn replace_weight(&mut self, weight: T) -> T {
+        core::mem::replace(self.weight_mut(), weight)
+    }
+    /// sets the weight to a new value
+    pub fn set_weight(&mut self, weight: T) {
+        self.weight = weight;
+    }
+    /// [`swap`](core::mem::swap) the weight with the weight of another synaptic event
+    pub const fn swap_weight(&mut self, other: &mut SynapticEvent<T>) {
+        core::mem::swap(self.weight_mut(), other.weight_mut());
+    }
+    /// [`take`](core::mem::take) the weight, leaving the default value in its place.
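+    ///
+    /// A quick sketch of the expected semantics (assuming `T = f64`):
+    ///
+    /// ```ignore
+    /// let mut ev = SynapticEvent::new(2.5);
+    /// assert_eq!(ev.take_weight(), 2.5);
+    /// assert_eq!(*ev.weight(), 0.0); // `f64::default()` is left behind
+    /// ```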
+    pub fn take_weight(&mut self) -> T
+    where
+        T: Default,
+    {
+        core::mem::take(&mut self.weight)
+    }
+}
+
+impl<T> AsRef<T> for SynapticEvent<T> {
+    fn as_ref(&self) -> &T {
+        self.weight()
+    }
+}
+
+impl<T> AsMut<T> for SynapticEvent<T> {
+    fn as_mut(&mut self) -> &mut T {
+        self.weight_mut()
+    }
+}
+
+impl<T> core::borrow::Borrow<T> for SynapticEvent<T> {
+    fn borrow(&self) -> &T {
+        self.weight()
+    }
+}
+
+impl<T> core::borrow::BorrowMut<T> for SynapticEvent<T> {
+    fn borrow_mut(&mut self) -> &mut T {
+        self.weight_mut()
+    }
+}
+
+impl<T> core::ops::Deref for SynapticEvent<T> {
+    type Target = T;
+
+    fn deref(&self) -> &Self::Target {
+        self.weight()
+    }
+}
+
+impl<T> core::ops::DerefMut for SynapticEvent<T> {
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.weight_mut()
+    }
+}
diff --git a/ext/src/snn/types/result.rs b/ext/src/snn/types/result.rs
index e1a5779b..8e4b87c5 100644
--- a/ext/src/snn/types/result.rs
+++ b/ext/src/snn/types/result.rs
@@ -15,19 +15,33 @@ pub struct StepResult<T> {
 }

 impl<T> StepResult<T> {
-    /// returns a new instance of the `StepResult`
-    pub const fn new(spiked: bool, v: T) -> Self {
-        Self { spiked, v }
+    /// returns a new instance of the `StepResult`;
+    ///
+    /// **Note**: defaults to a state of being _not spiked_.
+    pub const fn new(v: T) -> Self {
+        Self { spiked: false, v }
     }
-
+    /// returns a new, _spiked_ instance of the `StepResult`
     pub const fn spiked(v: T) -> Self {
         Self { spiked: true, v }
     }
-
-    pub const fn not_spiked(v: T) -> Self {
-        Self { spiked: false, v }
+    #[inline]
+    /// consumes the current instance to create another that is said to have _spiked_.
+    pub fn spike(self) -> Self {
+        Self {
+            spiked: true,
+            ..self
+        }
     }
-
+    #[inline]
+    /// consumes the current instance to create another that is said to have _not spiked_.
+    pub fn unspike(self) -> Self {
+        Self {
+            spiked: false,
+            ..self
+        }
+    }
+    /// returns true if the neuron spiked during this step
     pub const fn is_spiked(&self) -> bool {
         self.spiked
     }
@@ -36,3 +50,9 @@ impl<T> StepResult<T> {
         &self.v
     }
 }
+
+impl<T> PartialEq<bool> for StepResult<T> {
+    fn eq(&self, other: &bool) -> bool {
+        &self.spiked == other
+    }
+}
diff --git a/ext/src/snn/utils.rs b/ext/src/snn/utils.rs
new file mode 100644
index 00000000..aaade919
--- /dev/null
+++ b/ext/src/snn/utils.rs
@@ -0,0 +1,50 @@
+/*
+    Appellation: utils
+    Created At: 2025.12.08:15:53:49
+    Contrib: @FL03
+*/
+
+use super::{LIFNeuron, SynapticEvent};
+#[cfg(feature = "alloc")]
+use alloc::vec::Vec;
+use num_traits::{Float, FromPrimitive, NumAssign};
+
+/// A basic method for _discovering_ the minimum external drive required to make a spiking
+/// neuron spike
+pub fn sweep_for_min_drive<T>(step_size: T) -> T
+where
+    T: Float + FromPrimitive + NumAssign,
+{
+    let dt = T::from_f32(0.1).unwrap();
+    let t_sim = T::from_usize(1000).unwrap();
+    let steps = (t_sim / dt).to_usize().unwrap();
+    let presyn_spikes: Vec<(T, T)> = vec![]; // no extra synaptic drive
+
+    let mut i_ext = T::zero();
+    loop {
+        let mut neuron = LIFNeuron::<T>::default();
+        let mut events: Vec<Vec<SynapticEvent<T>>> = vec![Vec::new(); steps + 1];
+        for (t_spike, weight) in &presyn_spikes {
+            let idx = (*t_spike / dt).round().to_isize().unwrap();
+            if idx >= 0 && (idx as usize) < events.len() {
+                events[idx as usize].push(SynapticEvent { weight: *weight });
+            }
+        }
+        let mut spiked = false;
+        for step in 0..steps {
+            for ev in &events[step] {
+                neuron.apply_spike(ev.weight);
+            }
+            let res = neuron.step(dt, i_ext);
+            if res.is_spiked() {
+                spiked = true;
+                break;
+            }
+        }
+        if spiked {
+            break;
+        }
+        i_ext += step_size; // increment drive
+    }
+    i_ext
+}
diff --git a/ext/src/transformer/model.rs b/ext/src/transformer/model.rs
index a34df757..2cf4531b 100644
--- a/ext/src/transformer/model.rs
+++ b/ext/src/transformer/model.rs
@@ -3,15 +3,14 @@
     Contrib: @FL03
 */
 #[cfg(feature = "rand")]
-use cnc::rand_distr;
+use cnc::init::rand_distr::{Distribution, StandardNormal};
 use cnc::{
     DeepModelParams, Forward, Model, ModelFeatures, Norm, Params, ReLUActivation,
     SigmoidActivation, StandardModelConfig, Train,
 };
-use anyhow::Context;
-use ndarray::prelude::*;
-use ndarray::{Data, ScalarOperand};
+use ndarray::linalg::Dot;
+use ndarray::{Array1, Array2, ArrayBase, ArrayView1, Data, Ix1, Ix2, ScalarOperand};
 use num_traits::{Float, FromPrimitive, NumAssign};

 #[derive(Clone, Debug)]
@@ -94,7 +93,7 @@ where
     pub fn init(self) -> Self
     where
         T: 'static + Float + FromPrimitive,
-        rand_distr::StandardNormal: rand_distr::Distribution<T>,
+        StandardNormal: Distribution<T>,
     {
         let params = DeepModelParams::glorot_normal(self.features());
         TransformerModel { params, ..self }
@@ -131,23 +130,20 @@ where
     A: Float + FromPrimitive + ScalarOperand,
     V: ReLUActivation + SigmoidActivation,
     Params<A>: Forward<U, Output = V> + Forward<V, Output = V>,
-    for<'a> &'a U: ndarray::linalg::Dot<Array2<A>, Output = V> + core::ops::Add<&'a Array1<A>>,
+    for<'a> &'a U: Dot<Array2<A>, Output = V> + core::ops::Add<&'a Array1<A>>,
     V: for<'a> core::ops::Add<&'a Array1<A>, Output = V>,
 {
     type Output = V;

-    fn forward(&self, input: &U) -> Option<Self::Output> {
-        let mut output = self.params().input().forward_then(&input, |y| y.relu())?;
+    fn forward(&self, input: &U) -> Self::Output {
+        let mut output = self.params().input().forward(input).relu();
         for layer in self.params().hidden() {
-            output = layer.forward_then(&output, |y| y.relu())?;
+            output = layer.forward(&output).relu();
         }
-        let y = self
-            .params()
-            .output()
-            .forward_then(&output, |y| y.sigmoid())?;
-        Some(y)
+        output = self.params().output().forward(&output).sigmoid();
+        output
     }
 }

@@ -203,28 +199,15 @@ where
         let mut activations = Vec::new();
         activations.push(input.to_owned());

-        let mut output = self
-            .params()
-            .input()
-            .forward(&input)
-            .context("Output layer failed to forward propagate during training...")?
-            .relu();
+        let mut output = self.params().input().forward(&input).relu();

         activations.push(output.to_owned());
         // collect the activations of the hidden
         for layer in self.params().hidden() {
-            output = layer
-                .forward(&output)
-                .context("Hidden layer failed to forward propagate during training...")?
-                .relu();
+            output = layer.forward(&output).relu();

             activations.push(output.to_owned());
         }
-        output = self
-            .params()
-            .output()
-            .forward(&output)
-            .context("Input layer failed to forward propagate during training...")?
-            .sigmoid();
+        output = self.params().output().forward(&output).sigmoid();

         activations.push(output.to_owned());

         // Calculate output layer error
@@ -238,8 +221,7 @@ where
         // Update output weights
         self.params_mut()
             .output_mut()
-            .backward(activations.last().unwrap(), &delta, lr)
-            .context("Backward propagation failed...")?;
+            .backward(activations.last().unwrap(), &delta, lr);

         let num_hidden = self.features().layers();
         // Iterate through hidden layers in reverse order
@@ -255,9 +237,7 @@ where
             };
             // Normalize delta to prevent exploding gradients
             delta /= delta.l2_norm();
-            self.params_mut().hidden_mut()[i]
-                .backward(&activations[i + 1], &delta, lr)
-                .context("Backward propagation failed...")?;
+            self.params_mut().hidden_mut()[i].backward(&activations[i + 1], &delta, lr);
         }
         /*
             Backpropagate to the input layer
@@ -270,8 +250,7 @@ where
         delta /= delta.l2_norm(); // Normalize the delta to prevent exploding gradients
         self.params_mut()
             .input_mut()
-            .backward(&activations[1], &delta, lr)
-            .context("Input layer backward pass failed")?;
+            .backward(&activations[1], &delta, lr);

         Ok(loss)
     }
diff --git a/ext/tests/snn.rs b/ext/tests/snn.rs
index 620da0d6..ef208164 100644
--- a/ext/tests/snn.rs
+++ b/ext/tests/snn.rs
@@ -4,11 +4,11 @@
     Contrib: @FL03
 */
 use approx::assert_abs_diff_eq;
-use concision_ext::snn::SpikingNeuron;
+use concision_ext::snn::LIFNeuron;

 #[test]
 fn test_snn_neuron_resting_no_input() {
-    let mut n = SpikingNeuron::default();
+    let mut n = LIFNeuron::default();
     let dt = 1.0;
     // simulate 100 ms with no input -> should not spike and v near v_rest
     for _ in 0..100 {
@@ -16,31 +16,33 @@ fn test_snn_neuron_resting_no_input() {
         assert!(!res.is_spiked());
     }
     let v = n.membrane_potential();
-    assert_abs_diff_eq!(v, n.v_rest);
+    assert_abs_diff_eq!(v, &n.v_rest);
 }

 #[test]
-// #[ignore = "Need to fix"]
 fn test_snn_neuron_spikes() {
     // params
     let dt = 1f64;
     let i_ext = 50f64; // large i_ext to force spiking
     // neuron
-    let mut n = SpikingNeuron::default();
+    let mut n = LIFNeuron::default();
     let mut spiked = false;
     let mut steps = 0usize;
-    // apply strong constant external current for a while
+    // run until spiked or max steps reached
     while !spiked && steps < 1000 {
         spiked = n.step(dt, i_ext).is_spiked();
         steps += 1;
     }
-    assert!(spiked, "Neuron did not spike under strong current");
+    assert!(
+        spiked,
+        "Neuron did not spike under a strong current (i_ext = {i_ext})"
+    );
 }

 #[test]
 fn test_snn_neuron_synaptic_state_change() {
-    let mut n = SpikingNeuron::default();
-    let before = n.synaptic_state();
-    n.receive_spike(2.5);
-    assert!(n.synaptic_state() > before);
+    let mut n = LIFNeuron::default();
+    let before = *n.synaptic_state();
+    n.apply_spike(2.5);
+    assert!(*n.synaptic_state() > before);
 }
diff --git a/flake.lock b/flake.lock
new file mode 100644
index 00000000..e42d3e3e
--- /dev/null
+++ b/flake.lock
@@ -0,0 +1,60 @@
+{
+  "nodes": {
+    "flake-utils": {
+      "inputs": {
+        "systems": "systems"
+      },
+      "locked": {
+        "lastModified": 1731533236,
+        "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
+        "type": "github"
+      },
+      "original": {
+        "owner": "numtide",
+        "repo": "flake-utils",
+        "type": "github"
+      }
+    },
+    "nixpkgs": {
+      "locked": {
+        "lastModified": 1764389371,
+        "narHash": "sha256-Bq3kTfPl2q5pbJypVESXB1iYlcg35M7DWrmUS/aV27I=",
+        "owner": "NixOS",
+        "repo": "nixpkgs",
+        "rev": "7d9be7b1c63840e80f7518979518e89a596a0060",
+        "type": "github"
+      },
+      "original": {
"owner": "NixOS", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix index 3985a471..7830c5b6 100644 --- a/flake.nix +++ b/flake.nix @@ -15,7 +15,7 @@ { packages.default = rustPlatform.buildRustPackage { pname = "concision"; - version = "0.2.9"; + version = "0.3.0"; src = self; # "./."; # If Cargo.lock doesn't exist yet, remove or comment out this block: cargoLock = { diff --git a/init/Cargo.toml b/init/Cargo.toml index 7075650f..2afe0370 100644 --- a/init/Cargo.toml +++ b/init/Cargo.toml @@ -18,11 +18,10 @@ version.workspace = true all-features = false features = ["full"] rustc-args = ["--cfg", "docsrs"] -version = "v{{version}}" [package.metadata.release] no-dev-version = true -tag-name = "{{version}}" +tag-name = "v{{version}}" [lib] bench = false @@ -67,17 +66,17 @@ nightly = [] # ************* [FF:Dependencies] ************* std = [ - "alloc", - "ndarray/std", - "num-complex?/std", - "num-traits/std", - "num/std", - "rand/std", - "rand/std_rng", - "serde/std", - "strum/std", - "thiserror/std", - "tracing?/std" + "alloc", + "ndarray/std", + "num-complex?/std", + "num-traits/std", + "num/std", + "rand/std", + "rand/std_rng", + "serde/std", + "strum/std", + "thiserror/std", + "tracing?/std", ] wasi = [] @@ -97,13 +96,13 @@ rand = ["dep:rand", "dep:rand_distr", "num-complex?/rand", "num/rand", "rng"] rng = ["dep:getrandom", "rand?/small_rng", "rand?/thread_rng"] serde = [ - "dep:serde", - "dep:serde_derive", - "ndarray/serde", - "num-complex?/serde", - "num/serde", - "rand?/serde", - "rand_distr?/serde" + "dep:serde", + "dep:serde_derive", + "ndarray/serde", + "num-complex?/serde", + "num/serde", + "rand?/serde", + "rand_distr?/serde", ] tracing = ["dep:tracing"] diff --git a/init/src/distr/xavier.rs b/init/src/distr/xavier.rs index 1f51f133..6f22a4c0 100644 --- a/init/src/distr/xavier.rs +++ b/init/src/distr/xavier.rs @@ -69,7 +69,7 @@ mod impl_normal { /// tries creating a new [`Normal`] distribution with a mean of 0 and the computed /// standard deviation ($\sigma$) based on the number of inputs and outputs. pub fn distr(&self) -> crate::InitResult> { - Normal::new(T::zero(), self.std_dev()).map_err(Into::into) + Ok(Normal::new(T::zero(), self.std_dev())?) } /// returns a reference to the standard deviation of the distribution pub const fn std_dev(&self) -> T { diff --git a/init/src/error.rs b/init/src/error.rs index 2e045178..2349efda 100644 --- a/init/src/error.rs +++ b/init/src/error.rs @@ -2,33 +2,50 @@ appellation: error authors: @FL03 */ -/// this module defines the error type, [`InitError`], used for various initialization errors -/// that one may encounter within the library. +//! Initialization related errors and other useful primitives #[cfg(feature = "alloc")] use alloc::string::String; -use rand_distr::NormalError; -use rand_distr::uniform::Error as UniformError; -/// a private type alias for a [`Result`](core::result::Result) type that is used throughout -/// the library using an [`InitError`](InitError) as the error type. 
+/// a type alias for a [`Result`](core::result::Result) type that is used throughout
+/// the library using an [`InitError`] as the error type.
 pub type InitResult<T> = core::result::Result<T, InitError>;

+/// The [`InitError`] type enumerates various initialization errors while integrating with the
+/// external errors largely focused on randomization.
 #[derive(Debug, thiserror::Error)]
+#[non_exhaustive]
 pub enum InitError {
     #[cfg(feature = "alloc")]
     #[error("Failed to initialize with the given distribution: {0}")]
     DistributionError(String),
+    #[cfg(feature = "rng")]
+    #[error(transparent)]
+    RngError(#[from] getrandom::Error),
     #[cfg(feature = "rand")]
-    #[error("[NormalError] {0}")]
-    NormalError(NormalError),
+    #[error("[NormalError]: {0}")]
+    NormalError(rand_distr::NormalError),
     #[error(transparent)]
     #[cfg(feature = "rand")]
-    UniformError(#[from] UniformError),
+    UniformError(#[from] rand_distr::uniform::Error),
+    #[error("[WeibullError]: {0}")]
+    #[cfg(feature = "rand")]
+    WeibullError(rand_distr::WeibullError),
 }

 #[cfg(feature = "rand")]
-impl From<NormalError> for InitError {
-    fn from(err: NormalError) -> Self {
-        InitError::NormalError(err)
+mod rand_err {
+    use super::InitError;
+    use rand_distr::{NormalError, WeibullError};
+
+    impl From<NormalError> for InitError {
+        fn from(err: rand_distr::NormalError) -> Self {
+            InitError::NormalError(err)
+        }
+    }
+
+    impl From<WeibullError> for InitError {
+        fn from(err: rand_distr::WeibullError) -> Self {
+            InitError::WeibullError(err)
+        }
     }
 }
diff --git a/init/src/lib.rs b/init/src/lib.rs
index 129f3222..ab54d424 100644
--- a/init/src/lib.rs
+++ b/init/src/lib.rs
@@ -2,11 +2,13 @@
     Appellation: concision-init
     Contrib: FL03
 */
-//! One of the most important aspects of training neural networks and machine learning
-//! lies within the _initialization_ of model parameters. Here, we work to provide additional
-//! tools and utilities to facilitate effective initialization strategies including various
-//! random distributions tailored directly to machine learning workloads such as:
-//! Glorot (Xavier) initialization, LeCun initialization, etc.
+//! Initialization related tools and utilities for neural networks and machine learning models.
+//! This crate provides various initialization distributions and traits to facilitate
+//! the effective initialization of model parameters.
+//!
+//! ## Features
+//!
+//! - `rand`: Enables random number generation functionalities using the `rand` crate.
 //!
 //! Implementors of the [`Initialize`] trait can leverage the various initialization
 //! distributions provided within this crate to initialize their model parameters in a
@@ -17,7 +19,8 @@
     clippy::module_inception,
     clippy::needless_doctest_main,
     clippy::should_implement_trait,
-    clippy::upper_case_acronyms
+    clippy::upper_case_acronyms,
+    rustdoc::redundant_explicit_links
 )]
 #![cfg_attr(not(feature = "std"), no_std)]

@@ -76,18 +79,15 @@
 pub mod distr {
     //! this module implements various random distributions optimized for neural network
     //! initialization.
     #[doc(inline)]
-    pub use self::prelude::*;
+    pub use self::{lecun::*, trunc::*, xavier::*};

     pub mod lecun;
     pub mod trunc;
     pub mod xavier;

     pub(crate) mod prelude {
-        #[doc(inline)]
         pub use super::lecun::*;
-        #[doc(inline)]
         pub use super::trunc::*;
-        #[doc(inline)]
         pub use super::xavier::*;
     }
 }
diff --git a/init/src/traits/init.rs b/init/src/traits/init.rs
index b8919abb..ef58d36f 100644
--- a/init/src/traits/init.rs
+++ b/init/src/traits/init.rs
@@ -3,6 +3,16 @@
     Contrib: @FL03
 */

+#[cfg(feature = "rand")]
+use rand::RngCore;
+/// Initializes parameters and state from RNG and/or config.
+/// Macro will implement this to produce shaped params/state.
+pub trait Initialize {
+    type Output;
+
+    fn init_random<R: RngCore>(rng: &mut R) -> Self::Output;
+}
+
 /// A trait for creating custom initialization routines for models or other entities.
 pub trait Init {
     /// consumes the current instance to initialize a new one
diff --git a/init/src/traits/initialize.rs b/init/src/traits/initialize.rs
index 77055cd3..11d9be47 100644
--- a/init/src/traits/initialize.rs
+++ b/init/src/traits/initialize.rs
@@ -24,9 +24,9 @@ where

 #[deprecated(
     since = "0.2.9",
-    note = "Please use the `InitRand` trait instead which provides more comprehensive functionality."
+    note = "Please use the `NdInit` trait instead which provides more comprehensive functionality."
 )]
-pub trait Initialize<S, D>: InitRand<S, D>
+pub trait InitRand<S, D>: NdInit<S, D>
 where
     D: Dimension,
     S: RawData,
@@ -36,7 +36,7 @@ where
 /// The trait is similar to the `RandomExt` trait provided by the `ndarray_rand` crate,
 /// however, it is designed to be more generic, extensible, and optimized for neural network
 /// initialization routines.
-pub trait InitRand<S, D, A = <S as RawData>::Elem>: Sized
+pub trait NdInit<S, D, A = <S as RawData>::Elem>: Sized
 where
     D: Dimension,
     S: RawData,
@@ -212,7 +212,7 @@ where
 ************ Implementations ************
 */

-impl<S, D> InitRand<S, D> for ArrayBase<S, D>
+impl<S, D> NdInit<S, D> for ArrayBase<S, D>
 where
     D: Dimension,
     S: RawData,
diff --git a/init/src/utils/rand_utils.rs b/init/src/utils/rand_utils.rs
index 04b868f7..fbb7e15c 100644
--- a/init/src/utils/rand_utils.rs
+++ b/init/src/utils/rand_utils.rs
@@ -2,7 +2,7 @@
     Appellation: utils
     Contrib: FL03
 */
-use crate::InitRand;
+use crate::NdInit;
 use ndarray::{Array, ArrayBase, DataOwned, Dimension, IntoDimension, RawData, ShapeBuilder};
 use num::Num;
 use num::complex::{Complex, ComplexDistribution};
diff --git a/init/tests/init.rs b/init/tests/init.rs
index cf2ffc8e..0ceee1fb 100644
--- a/init/tests/init.rs
+++ b/init/tests/init.rs
@@ -4,7 +4,7 @@
 */
 extern crate concision_init as cnc;

-use cnc::InitRand;
+use cnc::NdInit;
 use cnc::distr::LecunNormal;
 use ndarray::prelude::*;
diff --git a/macros/Cargo.toml b/macros/Cargo.toml
index ceebad21..7a383b08 100644
--- a/macros/Cargo.toml
+++ b/macros/Cargo.toml
@@ -18,11 +18,10 @@ version.workspace = true
 all-features = false
 features = ["default"]
 rustc-args = ["--cfg", "docsrs"]
-version = "v{{version}}"

 [package.metadata.release]
 no-dev-version = true
-tag-name = "{{version}}"
+tag-name = "v{{version}}"

 [lib]
 bench = false
@@ -40,5 +39,5 @@ syn = { features = ["full"], version = "2" }
 default = []

 nightly = [
-  "proc-macro2/nightly"
-]
\ No newline at end of file
+    "proc-macro2/nightly",
+]
diff --git a/math.md b/math.md
deleted file mode 100644
index 243cfeb2..00000000
--- a/math.md
+++ /dev/null
@@ -1,15 +0,0 @@
-
-# Something
-
-$$ v \geq v_{thresh} \rightarrow{\delta}, v\leftarrow{v_{reset}}, \omega\mathrel{+}=b $$
-
-$$
-\tau_m * \frac{dv}{dt} = -(v - v_{rest}) + R*(I_{ext} + I_{syn}) - \omega
-$$
-
-$$\tau_w * \frac{d\omega}{dt} = -\omega$$
-$$\tau_s * \frac{ds}{dt} = -s$$
-
-where:
- - $`v`$: membrane potential
- - $\omega$: adaptation variable
diff --git a/params/Cargo.toml b/params/Cargo.toml
index 6c50c265..431c1f14 100644
--- a/params/Cargo.toml
+++ b/params/Cargo.toml
@@ -18,35 +18,30 @@ version.workspace = true
 all-features = false
 features = ["full"]
 rustc-args = ["--cfg", "docsrs"]
-version = "v{{version}}"

 [package.metadata.release]
 no-dev-version = true
-tag-name = "{{version}}"
+tag-name = "v{{version}}"

 [lib]
-crate-type = ["cdylib", "rlib"]
 bench = false
-doc = true
-doctest = true
+crate-type = ["cdylib", "rlib"]
+doctest = false
 test = true

-[[test]]
-name = "create"
-required-features = ["std"]
-
 [dependencies]
 concision-init = { workspace = true }
 concision-traits = { workspace = true }
 # custom
 variants = { workspace = true }
 # concurrency & parallelism
-rayon = { optional = true, workspace = true }
+rayon-core = { optional = true, workspace = true }
 # data & serialization
 serde = { features = ["derive"], optional = true, workspace = true }
 serde_derive = { optional = true, workspace = true }
 serde_json = { optional = true, workspace = true }
 # error-handling
+anyhow = { workspace = true }
 thiserror = { workspace = true }
 # mathematics
 approx = { optional = true, workspace = true }
@@ -58,16 +53,13 @@ getrandom = { default-features = false, optional = true, workspace = true }
 rand = { optional = true, workspace = true }
 rand_distr = { optional = true, workspace = true }

-[dev-dependencies]
-lazy_static = { workspace = true }
-
 [features]
 default = ["std"]

 full = [
-    "default",
     "approx",
     "complex",
+    "default",
     "json",
     "rand",
     "serde",
@@ -81,6 +73,7 @@ nightly = [
 # ************* [FF:Dependencies] *************

 std = [
     "alloc",
+    "anyhow/std",
     "concision-init/std",
     "concision-traits/std",
     "ndarray/std",
@@ -99,11 +92,14 @@ wasi = [
 ]

 wasm = [
-    "getrandom?/wasm_js",
     "concision-init/wasm",
     "concision-traits/wasm",
+    "getrandom?/wasm_js",
+    "rayon-core?/web_spin_lock",
 ]

+# ************* [FF:Dependencies] *************
+
 alloc = [
     "concision-init/alloc",
     "concision-traits/alloc",
@@ -113,8 +109,8 @@ alloc = [
 ]

 approx = [
-    "dep:approx",
     "concision-init/approx",
+    "dep:approx",
     "ndarray/approx",
 ]

 blas = [
@@ -124,30 +120,30 @@ blas = [
 ]

 complex = [
-    "dep:num-complex",
     "concision-init/complex",
+    "dep:num-complex",
 ]

 json = ["alloc", "serde", "serde_json"]

 rand = [
-    "dep:rand",
-    "dep:rand_distr",
     "concision-init/rand",
     "concision-traits/rand",
+    "dep:rand",
+    "dep:rand_distr",
     "num-complex?/rand",
     "rng",
 ]

 rayon = [
-    "dep:rayon",
+    "dep:rayon-core",
     "ndarray/rayon",
 ]

 rng = [
-    "dep:getrandom",
     "concision-init/rng",
     "concision-traits/rng",
+    "dep:getrandom",
     "rand?/small_rng",
     "rand?/thread_rng",
 ]
diff --git a/params/src/error.rs b/params/src/error.rs
index 58547e6f..8af4bf2b 100644
--- a/params/src/error.rs
+++ b/params/src/error.rs
@@ -12,14 +12,19 @@ pub type Result<T> = core::result::Result<T, ParamsError>;
 /// neural network.
 #[derive(Debug, thiserror::Error)]
 pub enum ParamsError {
-    #[error("Dimension Error: {0}")]
-    DimensionalError(String),
     #[error("Invalid biases")]
     InvalidBiases,
     #[error("Invalid weights")]
     InvalidWeights,
-    #[error("Invalid input shape")]
-    InvalidInputShape,
+    #[error(
+        "Unable to complete the operation due to a mismatch between shapes: expected {expected:?}, found {found:?}"
+    )]
+    MismatchedShapes {
+        expected: &'static [usize],
+        found: &'static [usize],
+    },
+    #[error("An invalid tensor of length {0} was provided")]
+    InvalidLength(usize),
     #[error("Invalid output shape")]
     InvalidOutputShape,
     #[error("Invalid parameter: {0}")]
@@ -28,6 +33,8 @@ pub enum ParamsError {
     InvalidParameterType,
     #[error("Invalid parameter value")]
     InvalidParameterValue,
+    #[error("Must be non-empty")]
+    EmptyInput,
     #[error(transparent)]
     ShapeError(#[from] ndarray::ShapeError),
 }
diff --git a/params/src/impls/impl_params.rs b/params/src/impls/impl_params.rs
index fdadbdfd..7186aaba 100644
--- a/params/src/impls/impl_params.rs
+++ b/params/src/impls/impl_params.rs
@@ -2,7 +2,7 @@
     appellation: impl_params
     authors: @FL03
 */
-use crate::params::ParamsBase;
+use crate::params_base::ParamsBase;
 use crate::traits::{Biased, Weighted};
 use core::iter::Once;
 use ndarray::{ArrayBase, Data, DataOwned, Dimension, Ix1, Ix2, RawData};
diff --git a/params/src/impls/impl_params_deprecated.rs b/params/src/impls/impl_params_deprecated.rs
index 4df62350..fc30e57c 100644
--- a/params/src/impls/impl_params_deprecated.rs
+++ b/params/src/impls/impl_params_deprecated.rs
@@ -2,7 +2,7 @@
     appellation: impl_params_iter
     authors: @FL03
 */
-use crate::params::ParamsBase;
+use crate::params_base::ParamsBase;

 use ndarray::{Dimension, RawData};
diff --git a/params/src/impls/impl_params_init.rs b/params/src/impls/impl_params_init.rs
deleted file mode 100644
index 319126a0..00000000
--- a/params/src/impls/impl_params_init.rs
+++ /dev/null
@@ -1,14 +0,0 @@
-/*
-    appellation: impl_params_init
-    authors: @FL03
-*/
-use crate::params::ParamsBase;
-
-use ndarray::{Dimension, RawData};
-
-impl<S, D> ParamsBase<S, D>
-where
-    D: Dimension,
-    S: RawData,
-{
-}
diff --git a/params/src/impls/impl_params_iter.rs b/params/src/impls/impl_params_iter.rs
index 1531aba0..58e07f91 100644
--- a/params/src/impls/impl_params_iter.rs
+++ b/params/src/impls/impl_params_iter.rs
@@ -2,7 +2,7 @@
     appellation: impl_params_iter
     authors: @FL03
 */
-use crate::params::ParamsBase;
+use crate::params_base::ParamsBase;

 use crate::iter::{Iter, IterMut};
 use ndarray::iter as nditer;
@@ -65,7 +65,7 @@ where
     {
         self.weights().iter()
     }
-    /// returns a mutable iterator over the weights; see [`iter_mut`](ArrayBase::iter_mut) for more
+    /// returns a mutable iterator over the weights; see [`iter_mut`](ndarray::iter::IterMut) for more
     pub fn iter_weights_mut(&mut self) -> nditer::IterMut<'_, A, D>
     where
         S: DataMut,
diff --git a/params/src/impls/impl_params_ops.rs b/params/src/impls/impl_params_ops.rs
index 390e4cf0..13283804 100644
--- a/params/src/impls/impl_params_ops.rs
+++ b/params/src/impls/impl_params_ops.rs
@@ -3,10 +3,12 @@
     Contrib: @FL03
 */
 use crate::{Params, ParamsBase};
-use concision_traits::{ApplyGradient, ApplyGradientExt, Backward, Forward, Norm};
+use concision_traits::{Backward, Forward, Norm};
 use ndarray::linalg::Dot;
-use ndarray::{ArrayBase, Axis, Data, DataMut, Dimension, Ix0, Ix1, Ix2, ScalarOperand};
-use num_traits::{Float, FromPrimitive};
+use ndarray::{
+    Array, ArrayBase, ArrayView, Data, Dimension, Ix0, Ix1, Ix2, RemoveAxis, ScalarOperand,
+};
+use num_traits::{Float, FromPrimitive, Num};

 impl<A, S, D> ParamsBase<S, D, A>
 where
@@ -14,15 +16,16 @@ where
     D: Dimension,
     S: Data<Elem = A>,
 {
-    /// perform a single backpropagation step
-    pub fn backward<X, Y>(&mut self, input: &X, grad: &Y, lr: A) -> Option<()>
+    /// execute a single backward propagation
+    pub fn backward<X, Y>(&mut self, input: &X, grad: &Y, lr: A)
     where
-        Self: Backward<X, Y, Elem = A, Output = ()>,
+        Self: Backward<X, Y, Elem = A>,
     {
         <Self as Backward<X, Y>>::backward(self, input, grad, lr)
     }
-    /// forward propagation
-    pub fn forward<X, Y>(&self, input: &X) -> Option<Y>
+    /// invoke a single forward step; this method is simply a convenience method implemented
+    /// to reduce the number of `Forward` imports.
+    pub fn forward<X, Y>(&self, input: &X) -> Y
     where
         Self: Forward<X, Output = Y>,
     {
@@ -36,7 +39,7 @@ where
     D: Dimension,
     S: Data<Elem = A>,
 {
-    /// Returns the L1 norm of the parameters (bias and weights).
+    /// computes the `l1` normalization of the current weights and biases
     pub fn l1_norm(&self) -> A {
         let bias = self.bias().l1_norm();
         let weights = self.weights().l1_norm();
@@ -48,201 +51,39 @@ where
         let weights = self.weights().l2_norm();
         bias + weights
     }
-
-    /// a convenience method used to apply a gradient to the parameters using the given
-    /// learning rate.
-    pub fn apply_gradient<Delta, Z>(&mut self, grad: &Delta, lr: A) -> Option<Z>
-    where
-        S: DataMut,
-        Self: ApplyGradient<Delta, A, Output = Z>,
-    {
-        <Self as ApplyGradient<Delta, A>>::apply_gradient(self, grad, lr)
-    }
-
-    pub fn apply_gradient_with_decay<Grad, Z>(&mut self, grad: &Grad, lr: A, decay: A) -> Option<Z>
-    where
-        S: DataMut,
-        Self: ApplyGradient<Grad, A, Output = Z>,
-    {
-        <Self as ApplyGradient<Grad, A>>::apply_gradient_with_decay(self, grad, lr, decay)
-    }
-
-    pub fn apply_gradient_with_momentum<Grad, V, Z>(
-        &mut self,
-        grad: &Grad,
-        lr: A,
-        momentum: A,
-        velocity: &mut V,
-    ) -> Option<Z>
-    where
-        S: DataMut,
-        Self: ApplyGradientExt<Grad, A, Velocity = V, Output = Z>,
-    {
-        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_momentum(
-            self, grad, lr, momentum, velocity,
-        )
-    }
-
-    pub fn apply_gradient_with_decay_and_momentum<Grad, V, Z>(
-        &mut self,
-        grad: &Grad,
-        lr: A,
-        decay: A,
-        momentum: A,
-        velocity: &mut V,
-    ) -> Option<Z>
-    where
-        S: DataMut,
-        Self: ApplyGradientExt<Grad, A, Velocity = V, Output = Z>,
-    {
-        <Self as ApplyGradientExt<Grad, A>>::apply_gradient_with_decay_and_momentum(
-            self, grad, lr, decay, momentum, velocity,
-        )
-    }
 }

-impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>, A> for ParamsBase<S, D>
-where
-    A: Float + FromPrimitive + ScalarOperand,
-    S: DataMut<Elem = A>,
-    T: Data<Elem = A>,
-    D: Dimension,
-{
-    type Output = ();
-
-    fn apply_gradient(&mut self, grad: &ParamsBase<T, D>, lr: A) -> Option<()> {
-        // apply the bias gradient
-        self.bias_mut().apply_gradient(grad.bias(), lr)?;
-        // apply the weight gradient
-        self.weights_mut().apply_gradient(grad.weights(), lr)?;
-        Some(())
-    }
-
-    fn apply_gradient_with_decay(
-        &mut self,
-        grad: &ParamsBase<T, D>,
-        lr: A,
-        decay: A,
-    ) -> Option<()> {
-        // apply the bias gradient
-        self.bias_mut()
-            .apply_gradient_with_decay(grad.bias(), lr, decay)?;
-        // apply the weight gradient
-        self.weights_mut()
-            .apply_gradient_with_decay(grad.weights(), lr, decay)?;
-        Some(())
-    }
-}
-
-impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>
+impl<X, Y, Z, A, S, D> Forward<X> for ParamsBase<S, D>
 where
-    A: Float + FromPrimitive + ScalarOperand,
-    S: DataMut<Elem = A>,
-    T: Data<Elem = A>,
+    A: Clone,
     D: Dimension,
-{
-    type Velocity = Params<A, D>;
-
-    fn apply_gradient_with_momentum(
-        &mut self,
-        grad: &ParamsBase<T, D>,
-        lr: A,
-        momentum: A,
-        velocity: &mut Self::Velocity,
-    ) -> Option<()> {
-        // apply the bias gradient
-        self.bias_mut().apply_gradient_with_momentum(
-            grad.bias(),
-            lr,
-            momentum,
-            velocity.bias_mut(),
-        )?;
-        // apply the weight gradient
-        self.weights_mut().apply_gradient_with_momentum(
-            grad.weights(),
-            lr,
-            momentum,
-            velocity.weights_mut(),
-        )?;
-        Some(())
-    }
-
-    fn apply_gradient_with_decay_and_momentum(
-        &mut self,
-        grad: &ParamsBase<T, D>,
-        lr: A,
-        decay: A,
-        momentum: A,
-        velocity: &mut Self::Velocity,
-    ) -> Option<()> {
-        // apply the bias gradient
-        self.bias_mut().apply_gradient_with_decay_and_momentum(
-            grad.bias(),
-            lr,
-            decay,
-            momentum,
-            velocity.bias_mut(),
-        )?;
-        // apply the weight gradient
-        self.weights_mut().apply_gradient_with_decay_and_momentum(
-            grad.weights(),
-            lr,
-            decay,
-            momentum,
-            velocity.weights_mut(),
-        )?;
-        Some(())
-    }
-}
-
-impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix2>> for Params<A>
-where
-    A: Float + FromPrimitive + ScalarOperand,
     S: Data<Elem = A>,
-    T: Data<Elem = A>,
+    for<'a> ArrayView<'a, A, D>: Dot<X, Output = Y>,
+    Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
 {
-    type Elem = A;
-    type Output = A;
+    type Output = Z;

-    fn backward(
-        &mut self,
-        input: &ArrayBase<S, Ix2>,
-        delta: &ArrayBase<T, Ix2>,
-        gamma: Self::Elem,
-    ) -> Option<Self::Output> {
-        // compute the weight gradient
-        let weight_delta = delta.t().dot(input);
-        // update the weights and bias
-        self.weights_mut().apply_gradient(&weight_delta, gamma)?;
-        self.bias_mut()
-            .apply_gradient(&delta.sum_axis(Axis(0)), gamma)?;
-        // return the sum of the squared delta
-        Some(delta.pow2().sum())
+    fn forward(&self, input: &X) -> Self::Output {
+        self.weights().t().dot(input) + self.bias()
     }
 }

-impl<A, S, T> Backward<ArrayBase<S, Ix0>, ArrayBase<T, Ix0>> for Params<A, Ix1>
+impl<A, S, T> Backward<ArrayBase<S, Ix1>, ArrayBase<T, Ix0>> for Params<A, Ix1>
 where
     A: Float + FromPrimitive + ScalarOperand,
     S: Data<Elem = A>,
     T: Data<Elem = A>,
 {
     type Elem = A;
-    type Output = A;

     fn backward(
         &mut self,
-        input: &ArrayBase<S, Ix0>,
+        input: &ArrayBase<S, Ix1>,
         delta: &ArrayBase<T, Ix0>,
         gamma: Self::Elem,
-    ) -> Option<Self::Output> {
-        // compute the weight gradient
-        let weight_delta = input * delta;
-        // update the weights and bias
-        self.weights_mut().apply_gradient(&weight_delta, gamma)?;
-        self.bias_mut().apply_gradient(delta, gamma)?;
-        // return the sum of the squared delta
-        Some(delta.pow2().sum())
+    ) {
+        self.weights_mut().scaled_add(gamma, &(input * delta));
+        self.bias_mut().scaled_add(gamma, delta);
     }
 }

@@ -253,64 +94,36 @@ where
     T: Data<Elem = A>,
 {
     type Elem = A;
-    type Output = A;

     fn backward(
         &mut self,
         input: &ArrayBase<S, Ix1>,
         delta: &ArrayBase<T, Ix0>,
         gamma: Self::Elem,
-    ) -> Option<Self::Output> {
-        // compute the weight gradient
-        let dw = &self.weights * delta.t().dot(input);
-        // update the weights and bias
-        self.weights_mut().apply_gradient(&dw, gamma)?;
-        self.bias_mut().apply_gradient(delta, gamma)?;
-        // return the sum of the squared delta
-        Some(delta.pow2().sum())
+    ) {
+        self.weights_mut().scaled_add(gamma, &(delta * input));
+        self.bias_mut().scaled_add(gamma, delta);
     }
 }

-impl<A, S, T> Backward<ArrayBase<S, Ix2>, ArrayBase<T, Ix1>> for Params<A>
+impl<A, S1, S2, D1, D2> Backward<ArrayBase<S1, D1>, ArrayBase<S2, D2>> for Params<A, D1>
 where
-    A: Float + FromPrimitive + ScalarOperand,
-    S: Data<Elem = A>,
-    T: Data<Elem = A>,
+    A: 'static + Copy + Num,
+    D1: RemoveAxis,
+    D2: Dimension,
+    S1: Data<Elem = A>,
+    S2: Data<Elem = A>,
+    for<'b> &'b ArrayBase<S2, D2>: Dot<ArrayBase<S1, D1>, Output = Array<A, D1>>,
 {
     type Elem = A;
-    type Output = A;

     fn backward(
         &mut self,
-        input: &ArrayBase<S, Ix2>,
-        delta: &ArrayBase<T, Ix1>,
+        input: &ArrayBase<S1, D1>,
+        delta: &ArrayBase<S2, D2>,
         gamma: Self::Elem,
-    ) -> Option<Self::Output> {
-        // compute the weight gradient
-        let weight_delta = input.dot(&delta.t());
-        // compute the bias gradient
-        let bias_delta = delta.sum_axis(Axis(0));
-
-        self.weights_mut().apply_gradient(&weight_delta, gamma)?;
-        self.bias_mut().apply_gradient(&bias_delta, gamma)?;
-        // return the sum of the squared delta
-        let y = input.dot(self.weights()) + self.bias();
-        let res = (&y - delta).pow2().sum();
-        Some(res)
-    }
-}
-
-impl<X, Y, Z, A, S, D> Forward<X> for ParamsBase<S, D>
-where
-    A: Clone,
-    D: Dimension,
-    S: Data<Elem = A>,
-    for<'a> X: Dot<ArrayBase<S, D>, Output = Y>,
-    Y: for<'a> core::ops::Add<&'a ArrayBase<S, D::Smaller>, Output = Z>,
-{
-    type Output = Z;
-
-    fn forward(&self, input: &X) -> Option<Self::Output> {
-        Some(input.dot(self.weights()) + self.bias())
+    ) {
+        self.weights_mut().backward(input, delta, gamma);
+        self.bias_mut().scaled_add(gamma, delta);
     }
 }
diff --git a/params/src/impls/impl_params_rand.rs b/params/src/impls/impl_params_rand.rs
index 248d1b1c..31094ce2 100644
--- a/params/src/impls/impl_params_rand.rs
+++ b/params/src/impls/impl_params_rand.rs
@@ -3,9 +3,9 @@
     Created At: 2025.11.26:15:28:12
     Contrib: @FL03
 */
-use crate::params::ParamsBase;
+use crate::params_base::ParamsBase;

-use concision_init::InitRand;
+use concision_init::NdInit;
 use ndarray::{
     ArrayBase, Axis, DataOwned, Dimension, RawData, RemoveAxis, ScalarOperand, ShapeBuilder,
 };
@@ -53,7 +53,7 @@ where
     }
 }

-impl<S, D> InitRand<S, D> for ParamsBase<S, D>
+impl<S, D> NdInit<S, D> for ParamsBase<S, D>
 where
     D: RemoveAxis,
     S: RawData,
diff --git a/params/src/impls/impl_params_serde.rs b/params/src/impls/impl_params_serde.rs
index 07bce482..dd410bcc 100644
--- a/params/src/impls/impl_params_serde.rs
+++ b/params/src/impls/impl_params_serde.rs
@@ -2,7 +2,7 @@
     appellation: impl_params_serde
     authors: @FL03
 */
-use crate::params::ParamsBase;
+use crate::params_base::ParamsBase;
 use ndarray::{Data, DataOwned, Dimension, RawData};
 use serde::de::{Deserialize, Deserializer, Error, Visitor};
 use serde::ser::{Serialize, SerializeStruct, Serializer};
diff --git a/params/src/lib.rs b/params/src/lib.rs
index 17ddb852..e6f7bac5 100644
--- a/params/src/lib.rs
+++ b/params/src/lib.rs
@@ -23,17 +23,17 @@
 //!
 #![cfg_attr(not(feature = "std"), no_std)]
 #![allow(
-    clippy::missing_saftey_doc,
+    clippy::missing_safety_doc,
     clippy::module_inception,
     clippy::needless_doctest_main,
-    clippy::upper_case_acronyms
+    clippy::should_implement_trait,
+    clippy::upper_case_acronyms,
+    rustdoc::redundant_explicit_links
 )]

 #[cfg(feature = "alloc")]
 extern crate alloc;
-extern crate concision_init as cnc_init;
-extern crate concision_traits as cnc_traits;
-extern crate ndarray as nda;
+extern crate ndarray as nd;

 #[cfg(all(not(feature = "alloc"), not(feature = "std")))]
 compiler_error! {
@@ -43,16 +43,21 @@ compiler_error! {

 pub mod error;
 pub mod iter;
-mod params;
+mod params_base;
+
+#[macro_use]
+pub(crate) mod macros {
+    #[macro_use]
+    pub mod seal;
+}

 mod impls {
     mod impl_params;
-    #[allow(deprecated)]
-    mod impl_params_deprecated;
-    #[cfg(feature = "rand")]
-    mod impl_params_init;
     mod impl_params_iter;
     mod impl_params_ops;
+
+    #[allow(deprecated)]
+    mod impl_params_deprecated;
     #[cfg(feature = "rand")]
     mod impl_params_rand;
     #[cfg(feature = "serde")]
@@ -61,8 +66,9 @@ mod impls {

 pub mod traits {
     //! Traits for working with model parameters
-    pub use self::wnb::*;
+    pub use self::{param::*, wnb::*};

+    mod param;
     mod wnb;
 }

@@ -76,12 +82,12 @@ mod types {

 // re-exports
 #[doc(inline)]
-pub use self::{error::*, params::ParamsBase, traits::*, types::*};
+pub use self::{error::*, params_base::ParamsBase, traits::*, types::*};

 // prelude
 #[doc(hidden)]
 pub mod prelude {
     pub use crate::error::ParamsError;
-    pub use crate::params::*;
+    pub use crate::params_base::*;
     pub use crate::traits::*;
     pub use crate::types::*;
 }
diff --git a/params/src/macros/seal.rs b/params/src/macros/seal.rs
new file mode 100644
index 00000000..05f2bbf6
--- /dev/null
+++ b/params/src/macros/seal.rs
@@ -0,0 +1,50 @@
+/*
+    Appellation: seal
+    Contrib: FL03
+*/
+//! The public parts of this private module are used to create traits
+//! that cannot be implemented outside of our own crate. This way we
+//! can feel free to extend those traits without worrying about it
being a breaking change for other implementations.
+//!
+//! ## Usage
+//!
+//! To define a private trait, you can use the [`private!`] macro, which will define a hidden
+//! method `__private__` that can only be implemented within the crate.
+
+/// If this type is pub but not publicly reachable, third parties
+/// can't name it and can't implement traits using it.
+#[allow(dead_code)]
+pub struct Seal;
+/// the [`private!`] macro is used to seal a particular trait, defining a hidden method that
+/// may only be implemented within the bounds of the crate.
+#[allow(unused_macros)]
+macro_rules! private {
+    () => {
+        /// This trait is private to implement; this method exists to make it
+        /// impossible to implement outside the crate.
+        #[doc(hidden)]
+        fn __private__(&self) -> $crate::macros::seal::Seal;
+    };
+}
+/// the [`seal!`] macro is used to implement a private method on a type, which is used to seal
+/// the type so that it cannot be implemented outside of the crate.
+#[allow(unused_macros)]
+macro_rules! seal {
+    () => {
+        fn __private__(&self) -> $crate::macros::seal::Seal {
+            $crate::macros::seal::Seal
+        }
+    };
+}
+/// the [`sealed!`] macro implements a trait for a type, sealing it so that it cannot be
+/// implemented outside of the crate. This is most useful for writing other macros that
+/// implement some raw, sealed trait on the given _types_.
+#[allow(unused_macros)]
+macro_rules! sealed {
+    (impl$(<$($T:ident),*>)? $trait:ident for $name:ident$(<$($V:ident),*>)? $(where $($rest:tt)*)?) => {
+        impl$(<$($T),*>)? $trait for $name$(<$($V),*>)? $(where $($rest)*)? {
+            seal!();
+        }
+    };
+}
diff --git a/params/src/params.rs b/params/src/params_base.rs
similarity index 87%
rename from params/src/params.rs
rename to params/src/params_base.rs
index 3b2e943d..0ef143f4 100644
--- a/params/src/params.rs
+++ b/params/src/params_base.rs
@@ -7,17 +7,21 @@ use ndarray::{
     ShapeBuilder,
 };
 
-/// The [`ParamsBase`] struct is a generic container for a set of weights and biases for a
-/// model where the bias tensor is always `n-1` dimensions smaller than the `weights` tensor.
+/// The [`ParamsBase`] implementation aims to provide a generic, n-dimensional weight and bias
+/// pair for a model (or layer). The object requires the bias tensor to be a single dimension
+/// smaller than the weights tensor.
+///
+/// In other words, the weight tensor defines the _shape_ of the parameters, and the bias
+/// shape is derived from it by removing one axis (typically the first).
 /// Consequently, this constrains the [`ParamsBase`] implementation to only support dimensions
-/// that can be reduced by one axis (i.e. $\mbox{rank}(D)>0$), which is typically the "zero-th" axis.
+/// that can be reduced by one axis, typically the "zero-th" axis: $`\mbox{rank}(D)>0`$.
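+///
+/// A minimal sketch of this shape contract, mirroring the crate's own tests (which use the
+/// owned, two-dimensional `Params` alias):
+///
+/// ```
+/// use concision_params::Params;
+///
+/// // the weights take on the full shape, (3, 4), while the
+/// // bias drops the first axis, leaving (4,)
+/// let params = Params::<f64>::ones((3, 4));
+/// assert_eq!(params.dim(), (3, 4));
+/// assert_eq!(params.bias().dim(), 4);
+/// ```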
 pub struct ParamsBase<S, D, A = <S as RawData>::Elem>
 where
     D: Dimension,
     S: RawData<Elem = A>,
 {
-    pub(crate) bias: ArrayBase<S, D::Smaller>,
-    pub(crate) weights: ArrayBase<S, D>,
+    pub bias: ArrayBase<S, D::Smaller>,
+    pub weights: ArrayBase<S, D>,
 }
 
 impl<S, D, A> ParamsBase<S, D, A>
 where
@@ -39,9 +43,8 @@ where
         F: Fn() -> A,
     {
         let shape = shape.into_shape_with_order();
-        let bshape = shape.raw_dim().remove_axis(Axis(0));
-        // initialize the bias and weights using the provided function for each element
-        let bias = ArrayBase::from_shape_fn(bshape, |_| init());
+        // initialize the bias using a shape that is one rank lower than the weights
+        let bias = ArrayBase::from_shape_fn(shape.raw_dim().remove_axis(Axis(0)), |_| init());
         let weights = ArrayBase::from_shape_fn(shape, |_| init());
         // create a new instance from the generated bias and weights
         Self::new(bias, weights)
@@ -234,7 +237,7 @@ where
         self.weights().len() + self.bias().len()
     }
     /// returns an owned instance of the parameters
-    pub fn to_owned(&self) -> ParamsBase<OwnedRepr<A>, D>
+    pub fn to_owned(&self) -> ParamsBase<OwnedRepr<A>, D>
     where
         A: Clone,
         S: DataOwned,
@@ -243,10 +246,7 @@ where
     }
     /// change the shape of the parameters; the shape of the bias parameters is determined by
     /// removing the "zero-th" axis of the given shape
-    pub fn to_shape<Sh>(
-        &self,
-        shape: Sh,
-    ) -> crate::Result<ParamsBase<OwnedRepr<A>, Sh::Dim>>
+    pub fn to_shape<Sh>(&self, shape: Sh) -> crate::Result<ParamsBase<OwnedRepr<A>, Sh::Dim>>
     where
         A: Clone,
         S: DataOwned,
@@ -261,24 +261,24 @@ where
     }
     /// returns a new [`ParamsBase`] instance with the same parameters, but using a shared
     /// representation of the data;
-    pub fn to_shared(&self) -> ParamsBase<OwnedArcRepr<A>, D>
+    pub fn to_shared(&self) -> ParamsBase<OwnedArcRepr<A>, D>
     where
         A: Clone,
         S: Data,
     {
         ParamsBase::new(self.bias().to_shared(), self.weights().to_shared())
     }
-    /// returns a "view" of the parameters; see [view](ArrayBase::view) for more information
-    pub fn view(&self) -> ParamsBase<ViewRepr<&'_ A>, D>
+    /// returns a "view" of the parameters; see [`view`](ndarray::ViewRepr) for more information
+    pub fn view(&self) -> ParamsBase<ViewRepr<&'_ A>, D>
     where
         S: Data,
     {
         ParamsBase::new(self.bias().view(), self.weights().view())
     }
-    /// returns mutable view of the parameters; see [view_mut](ArrayBase::view_mut) for more information
-    pub fn view_mut(&mut self) -> ParamsBase<ViewRepr<&'_ mut A>, D>
+    /// returns a mutable view of the parameters
+    pub fn view_mut(&mut self) -> ParamsBase<ViewRepr<&'_ mut A>, D>
     where
-        S: ndarray::DataMut,
+        S: DataMut,
     {
         ParamsBase::new(self.bias.view_mut(), self.weights.view_mut())
     }
diff --git a/params/src/traits/param.rs b/params/src/traits/param.rs
new file mode 100644
index 00000000..8d3fbfed
--- /dev/null
+++ b/params/src/traits/param.rs
@@ -0,0 +1,214 @@
+/*
+    Appellation: param
+    Created At: 2025.12.08:16:03:55
+    Contrib: @FL03
+*/
+/// The [`RawParam`] trait is used to denote objects capable of being used as a parameter
+/// within a neural network or machine learning context. Moreover, it provides us with the
+/// ability to associate some generic element type with the parameter and thus allows us to
+/// consider so-called _parameter spaces_. If we allow a parameter space to simply be a
+/// collection of points, then we can refine the definition downstream to consider specific
+/// interpolations, distributions, or manifolds. In other words, we are trying to construct
+/// a tangible configuration space for our models so that we can reason about optimization
+/// and training in a more formal manner.
+///
+/// **Note**: This trait is sealed and cannot be implemented outside of this crate.
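+///
+/// A small sketch of the intent (relying on the blanket implementations below and the
+/// crate-root re-exports of these traits):
+///
+/// ```
+/// use concision_params::{RawParam, TensorParams};
+///
+/// // scalars are parameters whose element type is themselves...
+/// fn assert_param<P: RawParam>(_: &P) {}
+/// assert_param(&1.0f64);
+///
+/// // ...while arrays expose their geometry through `TensorParams`
+/// let w = ndarray::Array2::<f64>::zeros((3, 4));
+/// assert_eq!(TensorParams::rank(&w), 2);
+/// assert_eq!(TensorParams::size(&w), 12);
+/// ```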
+pub trait RawParam {
+    type Elem: ?Sized;
+
+    private!();
+}
+
+/// The [`ScalarParam`] trait naturally extends the [`RawParam`] trait to define a
+/// scalar as a parameter whose element type is itself. This is useful for defining
+/// parameters which are simple scalars such as `f32` or `i64`.
+pub trait ScalarParam: RawParam<Elem = Self> + Sized {
+    private!();
+}
+
+pub trait TensorParams: RawParam {
+    type Shape: ?Sized;
+    /// returns the number of dimensions of the parameter
+    fn rank(&self) -> usize;
+    /// returns the shape of the parameter as a slice
+    fn shape(&self) -> &Self::Shape;
+    /// returns the size of the parameter
+    fn size(&self) -> usize;
+}
+
+/*
+ ************* Implementations *************
+*/
+use crate::ParamsBase;
+use ndarray::{ArrayBase, Dimension, RawData};
+
+impl<T> RawParam for &T
+where
+    T: RawParam,
+{
+    type Elem = T::Elem;
+
+    seal! {}
+}
+
+impl<T> RawParam for &mut T
+where
+    T: RawParam,
+{
+    type Elem = T::Elem;
+
+    seal! {}
+}
+
+impl<T> ScalarParam for T
+where
+    T: RawParam<Elem = T>,
+{
+    seal!();
+}
+
+macro_rules! impl_param {
+    ($($T:ty),* $(,)?) => {
+        $(impl_param!(@impl $T);)*
+    };
+    (@impl $T:ty) => {
+        impl RawParam for $T {
+            type Elem = $T;
+
+            seal! {}
+        }
+
+        impl TensorParams for $T {
+            type Shape = [usize; 0];
+
+            fn rank(&self) -> usize {
+                0
+            }
+
+            fn shape(&self) -> &Self::Shape {
+                &[]
+            }
+
+            fn size(&self) -> usize {
+                1
+            }
+        }
+    };
+}
+
+impl_param! {
+    u8, u16, u32, u64, u128, usize,
+    i8, i16, i32, i64, i128, isize,
+    f32, f64,
+    bool, char, str
+}
+
+#[cfg(feature = "alloc")]
+impl RawParam for alloc::string::String {
+    type Elem = u8;
+
+    seal! {}
+}
+
+impl<A, S, D> RawParam for ArrayBase<S, D>
+where
+    D: Dimension,
+    S: RawData<Elem = A>,
+{
+    type Elem = A;
+
+    seal! {}
+}
+
+impl<A, S, D> TensorParams for ArrayBase<S, D>
+where
+    D: Dimension,
+    S: RawData<Elem = A>,
+{
+    type Shape = [usize];
+
+    fn rank(&self) -> usize {
+        self.ndim()
+    }
+
+    fn shape(&self) -> &Self::Shape {
+        self.shape()
+    }
+
+    fn size(&self) -> usize {
+        self.len()
+    }
+}
+
+impl<A, S, D> RawParam for ParamsBase<S, D>
+where
+    D: Dimension,
+    S: RawData<Elem = A>,
+{
+    type Elem = A;
+
+    seal! {}
+}
+
+impl<A, S, D> TensorParams for ParamsBase<S, D>
+where
+    D: Dimension,
+    S: RawData<Elem = A>,
+{
+    type Shape = [usize];
+
+    fn rank(&self) -> usize {
+        self.weights().ndim()
+    }
+
+    fn shape(&self) -> &[usize] {
+        self.weights().shape()
+    }
+
+    fn size(&self) -> usize {
+        self.weights().len()
+    }
+}
+
+impl<T, const N: usize> RawParam for [T; N]
+where
+    T: RawParam,
+{
+    type Elem = T::Elem;
+
+    seal! {}
+}
+
+impl<T, const N: usize> TensorParams for [T; N]
+where
+    T: RawParam,
+{
+    type Shape = [usize; 1];
+
+    fn rank(&self) -> usize {
+        1
+    }
+
+    fn shape(&self) -> &Self::Shape {
+        &[N]
+    }
+
+    fn size(&self) -> usize {
+        N
+    }
+}
+
+#[cfg(feature = "alloc")]
+mod impl_alloc {
+    use super::*;
+    use alloc::vec::Vec;
+
+    impl<T> RawParam for Vec<T>
+    where
+        T: RawParam,
+    {
+        type Elem = T::Elem;
+
+        seal! {}
+    }
+}
diff --git a/params/src/traits/wnb.rs b/params/src/traits/wnb.rs
index 5e19f526..23ca94df 100644
--- a/params/src/traits/wnb.rs
+++ b/params/src/traits/wnb.rs
@@ -1,3 +1,8 @@
+/*
+    Appellation: wnb
+    Created At: 2025.11.28:21:21:42
+    Contrib: @FL03
+*/
 use ndarray::{ArrayBase, Data, DataMut, Dimension, RawData};
 
 pub trait Weighted<S, D, A = <S as RawData>::Elem>: Sized
 where
@@ -27,7 +32,7 @@ where
         *self.weights_mut() = weights;
         self
     }
-    /// returns an iterator over the weights
+    /// returns an iterator over the weights; see [`iter`](ndarray::iter::Iter) for more information
     fn iter_weights<'a>(&'a self) -> ndarray::iter::Iter<'a, S::Elem, D>
     where
         S: Data + 'a,
     {
         self.weights().iter()
     }
-    /// returns a mutable iterator over the weights; see [`iter_mut`](ArrayBase::iter_mut) for more
+    /// returns a mutable iterator over the weights; see [`iter_mut`](ndarray::iter::IterMut) for more information
     fn iter_weights_mut<'a>(&'a mut self) -> ndarray::iter::IterMut<'a, S::Elem, D>
     where
         S: DataMut + 'a,
diff --git a/params/src/types/aliases.rs b/params/src/types/aliases.rs
index fed92869..f787c182 100644
--- a/params/src/types/aliases.rs
+++ b/params/src/types/aliases.rs
@@ -2,7 +2,7 @@
     appellation: aliases
     authors: @FL03
 */
-use crate::params::ParamsBase;
+use crate::params_base::ParamsBase;
 
 use ndarray::{CowRepr, Ix2, OwnedArcRepr, OwnedRepr, RawViewRepr, ViewRepr};
 
diff --git a/params/tests/create.rs b/params/tests/create.rs
index 6e8abaef..aaa517d4 100644
--- a/params/tests/create.rs
+++ b/params/tests/create.rs
@@ -2,13 +2,12 @@
     Appellation: params
     Contrib: @FL03
 */
-extern crate concision_params as cnc;
 use concision_params::Params;
 
 #[test]
 fn test_params_ones() {
-    // weights retain the given shape (d_in, d_out)
-    // bias retains the shape (d_out,)
+    // weights retain the given shape (in, out)
+    // bias retains the shape (out,)
     let ones = Params::<f64>::ones((3, 4));
     assert_eq!(ones.dim(), (3, 4));
     assert_eq!(ones.bias().dim(), 4);
@@ -20,8 +19,8 @@ fn test_params_ones() {
 
 #[test]
 fn test_params_zeros() {
-    // weights retain the given shape (d_in, d_out)
-    // bias retains the shape (d_out,)
+    // weights retain the given shape (in, out)
+    // bias retains the shape (out,)
     let zeros = Params::<f64>::zeros((3, 4));
     assert_eq!(zeros.dim(), (3, 4));
     assert_eq!(zeros.bias().dim(), 4);
@@ -34,8 +33,8 @@ fn test_params_zeros() {
 
 #[test]
 #[cfg(feature = "rand")]
-fn test_params_init_rand() {
-    use concision_init::InitRand;
+fn test_params_init_rand() -> anyhow::Result<()> {
+    use concision_init::NdInit;
 
     let lecun = Params::<f64>::lecun_normal((3, 4));
     assert_eq!(lecun.dim(), (3, 4));
@@ -49,4 +48,6 @@ fn test_params_init_rand() {
     assert_ne!(glorot_norm, glorot_uniform);
     let truncnorm = Params::<f64>::truncnorm((3, 4), 0.0, 1.0).expect("truncnorm failed");
     assert_eq!(truncnorm.dim(), (3, 4));
+
+    Ok(())
 }
diff --git a/params/tests/propagate.rs b/params/tests/propagate.rs
new file mode 100644
index 00000000..7041232a
--- /dev/null
+++ b/params/tests/propagate.rs
@@ -0,0 +1,23 @@
+/*
+    Appellation: params
+    Contrib: @FL03
+*/
+use concision_params::Params;
+
+use ndarray::{Ix2, array};
+
+/*
+    Verify the dimensionality of the output of forward propagation.
+
+    Given parameters of shape (in, out) and an input of shape (..., in), the output
+    should have shape (..., out).
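+
+    Worked example (matching the test below): with ones-initialized weights of shape
+    (3, 4) and bias of shape (4,), the input [1, 2, 3] maps to
+    weights.t().dot(input) + bias = [6, 6, 6, 6] + [1, 1, 1, 1] = [7, 7, 7, 7].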
+*/
+#[test]
+fn test_params_fwd_dimensionality() {
+    let params = Params::<f64, Ix2>::ones((3, 4));
+    let input = array![1.0, 2.0, 3.0];
+    // should be of shape 4:
+    let output = params.forward(&input);
+    assert_eq!(output.dim(), 4);
+    assert_eq!(output, array![7.0, 7.0, 7.0, 7.0]);
+}
diff --git a/traits/Cargo.toml b/traits/Cargo.toml
index 86370343..251af85c 100644
--- a/traits/Cargo.toml
+++ b/traits/Cargo.toml
@@ -18,17 +18,15 @@ version.workspace = true
 all-features = false
 features = ["full"]
 rustc-args = ["--cfg", "docsrs"]
-version = "v{{version}}"
 
 [package.metadata.release]
 no-dev-version = true
-tag-name = "{{version}}"
+tag-name = "v{{version}}"
 
 [lib]
 bench = false
 crate-type = ["cdylib", "rlib"]
-doc = true
-doctest = true
+doctest = false
 test = true
 
 [[test]]
@@ -41,6 +39,7 @@ variants = { workspace = true }
 # data structures
 ndarray = { workspace = true }
 # error-handling
+anyhow = { workspace = true }
 thiserror = { workspace = true }
 # macros & utilities
 paste = { workspace = true }
@@ -50,7 +49,7 @@ num-complex = { optional = true, workspace = true }
 num-integer = { workspace = true }
 num-traits = { workspace = true }
 # random
-getrandom = { default-features = false, optional = true, workspace = true }
+getrandom = { optional = true, workspace = true }
 rand = { optional = true, workspace = true }
 rand_distr = { optional = true, workspace = true }
 
@@ -58,9 +57,9 @@ rand_distr = { optional = true, workspace = true }
 default = ["std"]
 
 full = [
-    "default",
-    "approx",
+    "approx",
     "complex",
+    "default",
     "rand",
 ]
 
@@ -68,14 +67,15 @@ nightly = []
 
 # ************* [FF:Dependencies] *************
 std = [
-    "alloc",
-    "ndarray/std",
-    "num-complex?/std",
+    "alloc",
+    "anyhow/std",
+    "ndarray/std",
+    "num-complex?/std",
     "num-integer/std",
-    "num-traits/std",
-    "rand?/std",
+    "num-traits/std",
+    "rand?/std",
     "rand?/std_rng",
-    "thiserror/std"
+    "thiserror/std",
 ]
 
 wasi = []
@@ -90,15 +90,17 @@ blas = ["ndarray/blas"]
 
 complex = ["dep:num-complex"]
 
+rayon = ["ndarray/rayon"]
+
 rand = [
-    "dep:rand",
-    "dep:rand_distr",
-    "num-complex?/rand",
-    "rng"
+    "dep:rand",
+    "dep:rand_distr",
+    "num-complex?/rand",
+    "rng",
 ]
 
 rng = [
-    "dep:getrandom",
-    "rand?/small_rng",
-    "rand?/thread_rng"
+    "dep:getrandom",
+    "rand?/small_rng",
+    "rand?/thread_rng",
 ]
diff --git a/traits/src/container.rs b/traits/src/container.rs
index 18f2099c..826e8a75 100644
--- a/traits/src/container.rs
+++ b/traits/src/container.rs
@@ -9,6 +9,29 @@ pub trait Container {
     type Item;
 }
 
+pub trait KeyValue {
+    type Cont<_K, _V>;
+    type Key;
+    type Value;
+
+    fn key(&self) -> &Self::Key;
+    fn value(&self) -> &Self::Value;
+}
+
+impl<K, V> KeyValue for (K, V) {
+    type Cont<_K, _V> = (_K, _V);
+    type Key = K;
+    type Value = V;
+
+    fn key(&self) -> &Self::Key {
+        &self.0
+    }
+
+    fn value(&self) -> &Self::Value {
+        &self.1
+    }
+}
+
 macro_rules! container {
container { ($( $($container:ident)::*<$A:ident $(, $B:ident)?> diff --git a/traits/src/entropy.rs b/traits/src/entropy.rs index a3fa711d..e5c5785c 100644 --- a/traits/src/entropy.rs +++ b/traits/src/entropy.rs @@ -26,6 +26,6 @@ where type Output = A; fn cross_entropy(&self) -> Self::Output { - self.mapv(|x| -x.ln()).mean().unwrap() + self.ln().mean().unwrap() } } diff --git a/traits/src/lib.rs b/traits/src/lib.rs index 69d93c29..257f629d 100644 --- a/traits/src/lib.rs +++ b/traits/src/lib.rs @@ -7,7 +7,9 @@ clippy::missing_safety_doc, clippy::module_inception, clippy::needless_doctest_main, - clippy::upper_case_acronyms + clippy::should_implement_trait, + clippy::upper_case_acronyms, + rustdoc::redundant_explicit_links )] #![cfg_attr(not(feature = "std"), no_std)] #![cfg_attr(feature = "nightly", feature(allocator_api))] @@ -15,11 +17,12 @@ #[cfg(not(any(feature = "std", feature = "alloc")))] compiler_error! { - "At least one of the 'std' or 'alloc' features must be enabled." + "At least one of the \"std\" or \"alloc\" features must be enabled for the crate to compile." } #[cfg(feature = "alloc")] extern crate alloc; +extern crate ndarray as nd; #[macro_use] pub(crate) mod macros { @@ -56,15 +59,27 @@ pub mod math { mod unary; } +pub mod ops { + //! composable operators for tensor manipulations and transformations, neural networks, and + //! more + #[allow(unused_imports)] + #[doc(inline)] + pub use self::{binary::*, unary::*}; + + mod binary; + mod unary; +} + pub mod tensor { #[doc(inline)] - pub use self::{fill::*, like::*, linalg::*, ndtensor::*, shape::*}; + pub use self::{dimensionality::*, fill::*, like::*, linalg::*, ndtensor::*, reshape::*}; + mod dimensionality; mod fill; mod like; mod linalg; mod ndtensor; - mod shape; + mod reshape; } // re-exports diff --git a/traits/src/math/gradient.rs b/traits/src/math/gradient.rs index 51a4573b..d1ed895c 100644 --- a/traits/src/math/gradient.rs +++ b/traits/src/math/gradient.rs @@ -6,11 +6,10 @@ /// the [`Gradient`] trait defines the gradient of a function, which is a function that /// takes an input and returns a delta, which is the change in the output with respect to /// the input. -pub trait Gradient { - type Elem; - type Delta<_S, _D>; +pub trait Gradient { + type Delta<_U>; - fn grad(&self, rhs: &Self::Delta) -> Self::Delta; + fn grad(&self, rhs: &Self::Delta) -> Self::Delta; } /// A trait declaring basic gradient-related routines for a neural network diff --git a/traits/src/ops/binary.rs b/traits/src/ops/binary.rs new file mode 100644 index 00000000..35f152e6 --- /dev/null +++ b/traits/src/ops/binary.rs @@ -0,0 +1,5 @@ +/* + Appellation: binary + Created At: 2025.12.09:07:27:17 + Contrib: @FL03 +*/ diff --git a/traits/src/ops/unary.rs b/traits/src/ops/unary.rs new file mode 100644 index 00000000..8b208c2f --- /dev/null +++ b/traits/src/ops/unary.rs @@ -0,0 +1,64 @@ +/* + Appellation: unary + Created At: 2025.12.09:07:26:04 + Contrib: @FL03 +*/ +/// [`Decrement`] is a chainable trait that defines a decrement method, +/// effectively removing a single unit from the original object to create another +pub trait Decrement { + type Output; + + fn dec(self) -> Self::Output; +} + +/// The [`DecrementMut`] trait defines a decrement method that operates in place, +/// modifying the original object. 
+pub trait DecrementMut {
+    fn dec_mut(&mut self);
+}
+/// The [`Increment`] trait is the chainable counterpart to [`Decrement`], adding a single
+/// unit to the original object to create another.
+pub trait Increment {
+    type Output;
+
+    fn inc(self) -> Self::Output;
+}
+
+pub trait IncrementMut {
+    fn inc_mut(&mut self);
+}
+
+/*
+ ************* Implementations *************
+*/
+use num_traits::One;
+
+impl<T> Decrement for T
+where
+    T: One + core::ops::Sub<Output = T>,
+{
+    type Output = T;
+
+    fn dec(self) -> Self::Output {
+        self - T::one()
+    }
+}
+
+impl<T> DecrementMut for T
+where
+    T: One + core::ops::SubAssign,
+{
+    fn dec_mut(&mut self) {
+        *self -= T::one()
+    }
+}
+
+impl<T> Increment for T
+where
+    T: One + core::ops::Add<Output = T>,
+{
+    type Output = T;
+
+    fn inc(self) -> Self::Output {
+        self + T::one()
+    }
+}
diff --git a/traits/src/predict.rs b/traits/src/predict.rs
index e41d87bb..4fa297de 100644
--- a/traits/src/predict.rs
+++ b/traits/src/predict.rs
@@ -17,7 +17,7 @@ pub trait Predict<Rhs> {
 
     private!();
 
-    fn predict(&self, input: &Rhs) -> Option<Self::Output>;
+    fn predict(&self, input: &Rhs) -> Self::Output;
 }
 
 /// The [`PredictWithConfidence`] trait is an extension of the [`Predict`] trait, providing
@@ -43,7 +43,7 @@ where
 
     seal!();
 
-    fn predict(&self, input: &U) -> Option<Self::Output> {
+    fn predict(&self, input: &U) -> Self::Output {
         self.forward(input)
     }
 }
@@ -58,7 +58,7 @@ where
 
     fn predict_with_confidence(&self, input: &U) -> Option<(Self::Output, Self::Confidence)> {
         // Get the base prediction
-        let prediction = Predict::predict(self, input)?;
+        let prediction = Predict::predict(self, input);
         let shape = prediction.shape();
         // Calculate confidence as the inverse of the variance of the output
         // For each sample, compute the variance across the output dimensions
@@ -73,7 +73,7 @@ where
         }
 
         // Average variance across the batch
-        let avg_variance = variance_sum / A::from_usize(batch_size).unwrap();
+        let avg_variance = variance_sum / A::from_usize(batch_size)?;
 
         // Confidence: inverse of variance (clipped to avoid division by zero)
         let confidence = (A::one() + avg_variance).recip();
diff --git a/traits/src/propagation.rs b/traits/src/propagation.rs
index a2f6a14b..d1840e2e 100644
--- a/traits/src/propagation.rs
+++ b/traits/src/propagation.rs
@@ -22,25 +22,32 @@ pub enum PropagationError {
 /// step in a neural network or machine learning model.
 pub trait Backward<X, Delta> {
     type Elem;
+
+    fn backward(&mut self, input: &X, delta: &Delta, gamma: Self::Elem);
+}
+
+pub trait BackwardStep<T> {
+    type Data<_X>;
+    type Grad<_X>;
     type Output;
 
-    fn backward(&mut self, input: &X, delta: &Delta, gamma: Self::Elem) -> Option<Self::Output>;
+    fn backward<X>(&mut self, input: &Self::Data<X>, delta: &Self::Grad<X>, gamma: T) -> Self::Output;
 }
 
-/// The [`Forward`] trait defines an interface that is used to perform a single forward step
-/// within a neural network or machine learning model.
+/// The [`Forward`] trait describes a common interface for objects designated to perform a
+/// single forward step in a neural network or machine learning model.
 pub trait Forward<Rhs> {
     type Output;
     /// a single forward step
-    fn forward(&self, input: &Rhs) -> Option<Self::Output>;
+    fn forward(&self, input: &Rhs) -> Self::Output;
     /// this method enables the forward pass to be generically _activated_ using some closure.
     /// This is useful for isolating the logic of the forward pass from that of the activation
     /// function and is often used by layers and models.
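+    ///
+    /// For example (a sketch; `layer` stands in for any `Forward` implementor whose
+    /// output supports `mapv`, such as an `ndarray`-backed model):
+    ///
+    /// ```ignore
+    /// let y = layer.forward_then(&x, |out| out.mapv(f64::tanh));
+    /// ```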
-    fn forward_then<F>(&self, input: &Rhs, then: F) -> Option<Self::Output>
+    fn forward_then<F>(&self, input: &Rhs, then: F) -> Self::Output
     where
         F: FnOnce(Self::Output) -> Self::Output,
     {
-        self.forward(input).map(then)
+        then(self.forward(input))
     }
 }
 
@@ -49,26 +56,30 @@ pub trait Forward<Rhs> {
  */
 use ndarray::linalg::Dot;
-use ndarray::{ArrayBase, Data, Dimension, LinalgScalar};
-use num_traits::FromPrimitive;
+use ndarray::{Array, ArrayBase, ArrayView, Data, DataMut, Dimension};
+use num_traits::Num;
 
-impl<A, S, D, X, Y, Dx> Backward<X, Y> for ArrayBase<S, D>
+impl<A, S, D, S1, D1, S2, D2> Backward<ArrayBase<S1, D1>, ArrayBase<S2, D2>>
+    for ArrayBase<S, D>
 where
-    A: LinalgScalar + FromPrimitive,
+    A: 'static + Copy + Num,
     D: Dimension,
-    S: ndarray::DataMut<Elem = A>,
-    Dx: core::ops::Mul<A, Output = Dx>,
-    for<'a> X: Dot<Y, Output = Dx>,
-    for<'a> &'a Self: core::ops::Add<&'a Dx, Output = Self>,
+    S: DataMut<Elem = A>,
+    D1: Dimension,
+    D2: Dimension,
+    S1: Data<Elem = A>,
+    S2: Data<Elem = A>,
+    for<'b> &'b ArrayBase<S1, D1>: Dot<ArrayView<'b, A, D2>, Output = Array<A, D>>,
 {
     type Elem = A;
-    type Output = ();
 
-    fn backward(&mut self, input: &X, delta: &Y, gamma: Self::Elem) -> Option<Self::Output> {
-        let dx = input.dot(delta);
-        let next = &*self + &(dx * gamma);
-        self.assign(&next);
-        Some(())
+    fn backward(
+        &mut self,
+        input: &ArrayBase<S1, D1>,
+        delta: &ArrayBase<S2, D2>,
+        gamma: Self::Elem,
+    ) {
+        self.scaled_add(gamma, &input.dot(&delta.t()))
     }
 }
 
@@ -81,7 +92,7 @@ where
 {
     type Output = Y;
 
-    fn forward(&self, input: &X) -> Option<Self::Output> {
-        Some(input.dot(self))
+    fn forward(&self, input: &X) -> Self::Output {
+        input.dot(self)
     }
 }
diff --git a/traits/src/rounding.rs b/traits/src/rounding.rs
index 92f9c340..36affa7a 100644
--- a/traits/src/rounding.rs
+++ b/traits/src/rounding.rs
@@ -3,24 +3,7 @@
     Created At: 2025.11.26:12:09:08
     Contrib: @FL03
 */
-use num_traits::{Float, Num};
-
-/// divide two values and round down to the nearest integer.
-fn floor_div<T>(numerator: T, denom: T) -> T
-where
-    T: Copy + core::ops::Div<Output = T> + core::ops::Rem<Output = T> + core::ops::Sub<Output = T>,
-{
-    (numerator - (numerator % denom)) / denom
-}
-
-/// Round the given value to the given number of decimal places.
-fn round_to<T>(val: T, decimals: usize) -> T
-where
-    T: Float,
-{
-    let factor = T::from(10).expect("").powi(decimals as i32);
-    (val * factor).round() / factor
-}
+use num_traits::Float;
 
 pub trait FloorDiv {
     type Output;
@@ -34,12 +17,12 @@ pub trait RoundTo {
 
 impl<T> FloorDiv for T
 where
-    T: Copy + Num,
+    T: Copy + core::ops::Div<Output = T> + core::ops::Rem<Output = T> + core::ops::Sub<Output = T>,
 {
     type Output = T;
 
     fn floor_div(self, rhs: Self) -> Self::Output {
-        floor_div(self, rhs)
+        (self - (self % rhs)) / rhs
     }
 }
 
@@ -48,6 +31,7 @@ where
     T: Float,
 {
     fn round_to(&self, places: usize) -> Self {
-        round_to(*self, places)
+        let factor = T::from(10).unwrap().powi(places as i32);
+        (*self * factor).round() / factor
     }
 }
diff --git a/traits/src/tensor/dimensionality.rs b/traits/src/tensor/dimensionality.rs
new file mode 100644
index 00000000..2217e87f
--- /dev/null
+++ b/traits/src/tensor/dimensionality.rs
@@ -0,0 +1,32 @@
+/*
+    Appellation: dimensionality
+    Created At: 2025.11.26:13:10:09
+    Contrib: @FL03
+*/
+
+/// the [`RawDimension`] trait is used to define a type that can be used as a raw dimension.
+/// This trait is primarily used to provide abstracted, generic interpretations of the
+/// dimensions of the [`ndarray`] crate to ensure long-term compatibility.
+pub trait RawDimension {
+    type Shape;
+
+    private! {}
+}
+
+pub trait Dim: RawDimension {
+    /// returns the total number of elements considered by the dimension
+    fn size(&self) -> usize;
+}
+
+/*
+ ************* Implementations *************
+*/
+
+impl<D> RawDimension for D
+where
+    D: nd::Dimension,
+{
+    type Shape = D::Pattern;
+
+    seal! {}
+}
diff --git a/traits/src/tensor/linalg.rs b/traits/src/tensor/linalg.rs
index e02aec70..63b14866 100644
--- a/traits/src/tensor/linalg.rs
+++ b/traits/src/tensor/linalg.rs
@@ -24,10 +24,10 @@ pub trait MatMul<Rhs> {
     fn matmul(&self, rhs: &Rhs) -> Self::Output;
 }
 
-/// The [`MatPow`] trait defines an interface for computing the exponentiation of a matrix.
+/// The [`MatPow`] trait defines an interface for computing the power of some matrix.
 pub trait MatPow<Rhs> {
     type Output;
-    /// raise the tensor to the power of the right-hand side, producing some [`Output`](Matpow::Output)
+
     fn matpow(&self, rhs: Rhs) -> Self::Output;
 }
diff --git a/traits/src/tensor/shape.rs b/traits/src/tensor/reshape.rs
similarity index 62%
rename from traits/src/tensor/shape.rs
rename to traits/src/tensor/reshape.rs
index 7e6c066b..41a5b90c 100644
--- a/traits/src/tensor/shape.rs
+++ b/traits/src/tensor/reshape.rs
@@ -1,36 +1,32 @@
 /*
-    Appellation: shape
+    Appellation: reshape
     Created At: 2025.11.26:13:10:09
     Contrib: @FL03
 */
-/// the [`RawDimension`] trait is used to define a type that can be used as a raw dimension.
-/// This trait is primarily used to provide abstracted, generic interpretations of the
-/// dimensions of the [`ndarray`] crate to ensure long-term compatibility.
-pub trait Dim {
-    private! {}
+/// The [`Unsqueeze`] trait establishes an interface for a routine that _unsqueezes_ an array,
+/// by inserting a new axis at a specified position. This is useful for reshaping arrays to
+/// meet specific dimensional requirements.
+pub trait Unsqueeze {
+    type Output;
+
+    fn unsqueeze(self, axis: usize) -> Self::Output;
 }
 
-/// The [`DecrementAxis`] trait defines a method enabling an axis to decrement itself,
+/// The [`DecrementAxis`] trait is used as a unary operator for removing a single axis
+/// from a multidimensional array or tensor-like structure.
 pub trait DecrementAxis {
     type Output;
 
-    fn dec(&self) -> Self::Output;
+    fn dec_axis(&self) -> Self::Output;
 }
 
 /// The [`IncrementAxis`] trait defines a method enabling an axis to increment itself,
 /// effectively adding a new axis to the array.
 pub trait IncrementAxis {
     type Output;
 
-    fn inc(&self) -> Self::Output;
-}
-/// The [`Unsqueeze`] trait establishes an interface for a routine that _unsqueezes_ an array,
-/// by inserting a new axis at a specified position. This is useful for reshaping arrays to
-/// meet specific dimensional requirements.
-pub trait Unsqueeze {
-    type Output;
-
-    fn unsqueeze(self, axis: usize) -> Self::Output;
+    fn inc_axis(self) -> Self::Output;
 }
 
 /*
@@ -38,25 +34,31 @@ pub trait Unsqueeze {
 */
 use ndarray::{ArrayBase, Axis, Dimension, RawData, RawDataClone, RemoveAxis};
 
-impl<D> Dim for D
+impl<D, E> DecrementAxis for D
 where
-    D: ndarray::Dimension,
+    D: RemoveAxis,
+    E: Dimension,
 {
-    seal! {}
+    type Output = E;
+
+    fn dec_axis(&self) -> Self::Output {
+        self.remove_axis(Axis(self.ndim() - 1))
+    }
 }
 
-impl<D> DecrementAxis for D
+impl<D, E> IncrementAxis for D
 where
-    D: RemoveAxis,
+    D: Dimension,
+    E: Dimension,
 {
-    type Output = D::Smaller;
+    type Output = E;
 
-    fn dec(&self) -> Self::Output {
-        self.remove_axis(Axis(self.ndim() - 1))
+    fn inc_axis(self) -> Self::Output {
+        self.insert_axis(Axis(self.ndim()))
     }
 }
 
-impl Unsqueeze for ArrayBase
+impl Unsqueeze for ArrayBase
 where
     D: Dimension,
     S: RawData,
@@ -68,7 +70,7 @@ where
     }
 }
 
-impl Unsqueeze for &ArrayBase
+impl Unsqueeze for &ArrayBase
 where
     D: Dimension,
     S: RawDataClone,