diff --git a/Cargo.lock b/Cargo.lock index de83507..ebad646 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "allocator-api2" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" + [[package]] name = "anstream" version = "0.6.18" @@ -70,6 +76,21 @@ dependencies = [ "wit-component", ] +[[package]] +name = "augurs-outlier" +version = "0.10.1" +source = "git+https://github.com/grafana/augurs?branch=rand-0.9#631ce6ae7abf3724416248e92595833b37ce1b0e" +dependencies = [ + "itertools 0.14.0", + "roots", + "rustc-hash", + "rv", + "serde", + "thiserror", + "tinyvec", + "tracing", +] + [[package]] name = "autocfg" version = "1.4.0" @@ -93,6 +114,12 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36" +[[package]] +name = "cfg-if" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2fd1289c04a9ea8cb22300a459a72a385d7c73d3259e2ed7dcb2af674838cfa9" + [[package]] name = "clap" version = "4.5.48" @@ -151,6 +178,12 @@ version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" +[[package]] +name = "doc-comment" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fea41bba32d969b513997752735605054bc0dfa92b4c56bf1189f2e174be7a10" + [[package]] name = "either" version = "1.15.0" @@ -187,6 +220,48 @@ dependencies = [ "wit-component", ] +[[package]] +name = "example-outlier" +version = "0.0.2" +dependencies = [ + "augurs-outlier", + "getrandom", + "wit-bindgen", + "wit-component", +] + +[[package]] +name = "example-records" +version = "0.0.2" +dependencies = [ + "wit-bindgen", + "wit-component", +] + +[[package]] +name = "example-resources" +version = "0.0.2" +dependencies = [ + "wit-bindgen", + "wit-component", +] + +[[package]] +name = "example-resources-simple" +version = "0.0.2" +dependencies = [ + "wit-bindgen", + "wit-component", +] + +[[package]] +name = "example-tuples" +version = "0.0.2" +dependencies = [ + "wit-bindgen", + "wit-component", +] + [[package]] name = "foldhash" version = "0.1.4" @@ -304,6 +379,18 @@ dependencies = [ "syn", ] +[[package]] +name = "getrandom" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26145e563e54f2cadc477553f1ec5ee650b00862f0a58bcd12cbdc5f0ea2d2f4" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasi", +] + [[package]] name = "glob" version = "0.3.3" @@ -316,6 +403,8 @@ version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" dependencies = [ + "allocator-api2", + "equivalent", "foldhash", ] @@ -365,6 +454,24 @@ version = "1.70.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b192c782037fadd9cfa75548310488aabdbf3d2da73885b31bd0abd03351285" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.14" @@ -383,12 +490,27 @@ version = "0.2.176" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "58f929b4d672ea937a23a1ab494143d968337a5f47e56d0815df1e0890ddf174" +[[package]] +name = "libm" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9fbbcab51052fe104eb5e5d351cf728d30a5be1fe14d9be8a3b097481fb97de" + [[package]] name = "log" version = "0.4.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04cbf5b083de1c7e0222a7a51dbfdba1cbe1c6ab0b15e29fff3f6c077fd9cd9f" +[[package]] +name = "lru" +version = "0.12.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "234cf4f4a04dc1f57e24b96cc0cd600cf2af460d4161ac5ecdd0af8e1f3b2a38" +dependencies = [ + "hashbrown", +] + [[package]] name = "memchr" version = "2.7.4" @@ -401,6 +523,80 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + [[package]] name = "once_cell" version = "1.20.3" @@ -429,6 +625,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "ppv-lite86" +version = "0.2.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85eae3c4ed2f50dcfe72643da4befc30deadb458a9b590d720cde2f2b1e97da9" +dependencies = [ + "zerocopy", +] + [[package]] name = "prettyplease" version = "0.2.29" @@ -457,6 +662,51 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99d9a13982dcf210057a8a78572b2217b667c3beacbf3a0d8b454f6f82837d38" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a8615d50dcf34fa31f7ab52692afec947c4dd0ab803cc87cb3b0b4570ff7463" +dependencies = [ + "num-traits", + "rand", +] + [[package]] name = "rayon" version = "1.11.0" @@ -483,6 +733,33 @@ version = "1.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ba39f3699c378cd8970968dcbff9c43159ea4cfbd88d43c00b22f2ef10a435d2" +[[package]] +name = "roots" +version = "0.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "082f11ffa03bbef6c2c6ea6bea1acafaade2fd9050ae0234ab44a2153742b058" + +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + +[[package]] +name = "rv" +version = "0.18.1" +source = "git+https://github.com/sd2k/rv?branch=bump-rand-dependency#ecc65506404da9d2def3f2b53a83c9de9f7c8401" +dependencies = [ + "doc-comment", + "itertools 0.13.0", + "lru", + "num", + "num-traits", + "rand", + "rand_distr", + "special", +] + [[package]] name = "ryu" version = "1.0.19" @@ -599,6 +876,15 @@ dependencies = [ "anstream", ] +[[package]] +name = "special" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b89cf0d71ae639fdd8097350bfac415a41aabf1d5ddd356295fdc95f09760382" +dependencies = [ + "libm", +] + [[package]] name = "strsim" version = "0.11.1" @@ -616,6 +902,41 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thiserror" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f63587ca0f12b72a0600bcba1d40081f830876000bb46dd2337a3051618f4fc8" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "2.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ff15c8ecd7de3849db632e14d18d2571fa09dfc5ed93479bc4485c7a517c913" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tinyvec" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa5fdc3bce6191a1dbc8c02d5c8bffcf557bafa17c124c5264a458f1b0613fa" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + [[package]] name = "toml_datetime" version = "0.7.2" @@ -655,6 +976,37 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d163a63c116ce562a22cda521fcc4d79152e7aba014456fb5eb442f6d6a10109" +[[package]] +name = "tracing" +version = "0.1.41" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "784e0ac535deb450455cbfa28a6f0df145ea1bb7ae51b821cf5e7927fdcfbdd0" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81383ab64e72a7a8b8e13130c49e3dab29def6d0c7d76a03087b3cf71c5c6903" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "tracing-core" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9d12581f227e93f094d3af2ae690a574abb8a2b9b7a96e7cfe9647b2b617678" +dependencies = [ + "once_cell", +] + [[package]] name = "trycmd" version = "0.15.10" @@ -700,6 +1052,24 @@ dependencies = [ "libc", ] +[[package]] +name = "wasi" +version = "0.14.7+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "883478de20367e224c0090af9cf5f9fa85bed63a95c1abf3afc5c083ebc06e8c" +dependencies = [ + "wasip2", +] + +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "wasm-encoder" version = "0.239.0" @@ -906,3 +1276,23 @@ dependencies = [ "unicode-xid", "wasmparser", ] + +[[package]] +name = "zerocopy" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0894878a5fa3edfd6da3f88c4805f4c8558e2b996227a3d864f47fe11e38282c" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.8.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d2b8d9c68ad2b9e4340d7832716a4d21a22a1154777ad56ea55c51a9cf3831" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index c57e357..c96e466 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,23 @@ [workspace] resolver = "3" members = ["cmd/*", "examples/*"] + +[patch.crates-io] +# # meanwhile, had to fork to add a feature flag for timers +# argmin = { git = "https://github.com/sd2k/argmin", branch = "timer-feature" } +# argmin-math = { git = "https://github.com/sd2k/argmin", branch = "timer-feature" } + +# # Depends on rv PR and release (see below). +# changepoint = { git = "https://github.com/sd2k/changepoint", branch = "remove-rayon" } +# # See PR at https://github.com/promised-ai/rv/pull/48. +rv = { git = "https://github.com/sd2k/rv", branch = "bump-rand-dependency" } + +# # See PR at https://github.com/statrs-dev/statrs/pull/331. +# statrs = { git = "https://github.com/RobertJacobsonCDC/statrs", branch = "update_rand_to_09" } + +# # Depends on cap-rand PR/release (see below). +# wasmtime = { git = "https://github.com/sd2k/wasmtime", branch = "bump-rand" } +# wasmtime-wasi = { git = "https://github.com/sd2k/wasmtime", branch = "bump-rand" } +# wasmtime-wasi-io = { git = "https://github.com/sd2k/wasmtime", branch = "bump-rand" } +# # See PR at https://github.com/bytecodealliance/cap-std/pull/394 +# cap-rand = { git = "https://github.com/sd2k/cap-std", branch = "bump-rand" } diff --git a/cmd/gravity/src/codegen/bindings.rs b/cmd/gravity/src/codegen/bindings.rs index 642de34..2cdd788 100644 --- a/cmd/gravity/src/codegen/bindings.rs +++ b/cmd/gravity/src/codegen/bindings.rs @@ -56,7 +56,7 @@ impl<'a> Bindings<'a> { pub fn generate(&mut self) { let (imports, chains) = self.generate_imports(); self.generate_factory(&imports, chains); - self.generate_exports(&imports.instance_name); + self.generate_exports(&imports.instance_name, &imports); } /// Generates the imports for the bindings. @@ -89,12 +89,13 @@ impl<'a> Bindings<'a> { /// /// Note: for now this only generates functions; types and interfaces are /// still TODO - fn generate_exports(&mut self, instance: &GoIdentifier) { + fn generate_exports(&mut self, instance: &GoIdentifier, analyzed_imports: &AnalyzedImports) { let config = ExportConfig { instance, world: self.world, resolve: self.resolve, sizes: self.sizes, + analyzed_imports, }; ExportGenerator::new(config).format_into(&mut self.out) } diff --git a/cmd/gravity/src/codegen/exports.rs b/cmd/gravity/src/codegen/exports.rs index 9005aba..ff2d66e 100644 --- a/cmd/gravity/src/codegen/exports.rs +++ b/cmd/gravity/src/codegen/exports.rs @@ -1,6 +1,10 @@ use genco::prelude::*; -use wit_bindgen_core::wit_parser::{Function, Resolve, SizeAlign, World, WorldItem}; +use wit_bindgen_core::wit_parser::{ + Function, FunctionKind, Resolve, SizeAlign, TypeDefKind, World, WorldItem, +}; +use crate::codegen::ir::AnalyzedImports; +use crate::go::imports::{WAZERO_API_DECODE_U32, WAZERO_API_MODULE}; use crate::go::{GoIdentifier, GoResult, GoType, imports::CONTEXT_CONTEXT}; pub struct ExportConfig<'a> { @@ -8,6 +12,7 @@ pub struct ExportConfig<'a> { pub world: &'a World, pub resolve: &'a Resolve, pub sizes: &'a SizeAlign, + pub analyzed_imports: &'a AnalyzedImports, } pub struct ExportGenerator<'a> { @@ -31,25 +36,108 @@ impl<'a> ExportGenerator<'a> { /// `wit_bindgen_core::abi::call` function. This will call `Func::emit` lots of /// times, one for each instruction in the function, and `Func::emit` will generate /// Go code for each instruction - fn generate_function(&self, func: &Function, tokens: &mut Tokens) { + fn generate_function( + &self, + func: &Function, + interface_name: &str, + is_interface_export: bool, + tokens: &mut Tokens, + ) { + use wit_bindgen_core::wit_parser::{Handle, Type, TypeDefKind}; + + // Validate: we only support borrow parameters, not owned + // This simplifies lifecycle management dramatically + for (param_name, wit_type) in &func.params { + if let Type::Id(type_id) = wit_type { + if let Some(type_def) = self.config.resolve.types.get(*type_id) { + if let TypeDefKind::Handle(Handle::Own(_)) = type_def.kind { + panic!( + "Function '{}' has owned resource parameter '{}'. \n\ + Gravity only supports borrow parameters in exports to simplify resource lifecycle management.\n\ + Owned parameters would require complex state tracking (generation counters, state machines, etc.)\n\ + \n\ + To fix:\n\ + - Change 'func(resource: foo)' to 'func(resource: borrow)'\n\ + - Owned returns (func() -> foo) are still supported!\n\ + \n\ + This design keeps host resource management simple and explicit while still allowing\n\ + guests to create and return resources to the host.", + func.name, param_name + ); + } + } + } + } + let params = func .params .iter() - .map( - |(name, wit_type)| match crate::resolve_type(wit_type, self.config.resolve) { + .map(|(name, wit_type)| { + let resolved = crate::resolve_type(wit_type, self.config.resolve); + let prefixed = self.resolve_type_with_interface(&resolved, interface_name); + match prefixed { GoType::ValueOrOk(t) => (GoIdentifier::local(name), *t), t => (GoIdentifier::local(name), t), - }, - ) + } + }) .collect::>(); let result = if let Some(wit_type) = &func.result { - GoResult::Anon(crate::resolve_type(wit_type, self.config.resolve)) + let resolved = crate::resolve_type(wit_type, self.config.resolve); + GoResult::Anon(self.resolve_type_with_interface(&resolved, interface_name)) } else { GoResult::Empty }; - let mut f = crate::Func::export(result, self.config.sizes); + // Detect if function has resource parameters or returns + let resource_info = self.detect_resource_in_function(func); + + let mut f = if let Some((res_name, _)) = &resource_info { + crate::Func::export_with_resource( + result, + self.config.sizes, + interface_name.to_string(), + res_name.clone(), + ) + } else { + crate::Func::export(result, self.config.sizes) + }; + // Build the full qualified export name for the wasm function call + let qualified_name = if is_interface_export && !interface_name.is_empty() { + // Get the full interface name with package + let full_interface_name = self + .config + .world + .exports + .iter() + .find_map(|(key, _)| { + if let wit_bindgen_core::wit_parser::WorldKey::Interface(id) = key { + let iface = &self.config.resolve.interfaces[*id]; + if let Some(name) = &iface.name { + let short_name = name.split('/').last().unwrap_or(name); + if short_name == interface_name { + if let Some(package_id) = iface.package { + let package = &self.config.resolve.packages[package_id]; + return Some(format!( + "{}:{}/{}", + package.name.namespace, package.name.name, name + )); + } + } + } + } + None + }) + .unwrap_or_else(|| interface_name.to_string()); + format!("{}#{}", full_interface_name, func.name) + } else { + func.name.clone() + }; + + // Set the qualified export name before generating the body + // This avoids string replacement which breaks genco's import tracking + f.set_export_name(qualified_name); + wit_bindgen_core::abi::call( self.config.resolve, wit_bindgen_core::abi::AbiVariant::GuestExport, @@ -67,9 +155,46 @@ impl<'a> ExportGenerator<'a> { .map(|(arg, (param, _))| (arg, param)) .collect::>(); let fn_name = &GoIdentifier::public(&func.name); + + // Collect resource type parameters for instance receiver + let mut type_params = Vec::new(); + for interface in &self.config.analyzed_imports.interfaces { + let interface_name = interface.name.split('/').last().unwrap_or(&interface.name); + for method in &interface.methods { + if method.name.contains("[constructor]") { + if let Some(ret) = &method.return_type { + if let crate::go::GoType::OwnHandle(name) + | crate::go::GoType::BorrowHandle(name) + | crate::go::GoType::Resource(name) = &ret.go_type + { + let prefixed_name = format!("{}-{}", interface_name, name); + let pointer_interface_name = + GoIdentifier::public(format!("p-{}", prefixed_name)); + let value_type_param = + GoIdentifier::public(format!("t-{}-value", prefixed_name)); + let pointer_type_param = + GoIdentifier::public(format!("p-t-{}", prefixed_name)); + type_params.push(( + value_type_param, + pointer_type_param, + pointer_interface_name, + )); + } + } + } + } + } + + // Build receiver with or without type parameters + let receiver = if !type_params.is_empty() { + quote!(*$(self.config.instance)[$(for (value_param, pointer_param, _) in &type_params join (, ) => $value_param, $pointer_param)]) + } else { + quote!(*$(self.config.instance)) + }; + quote_in! { *tokens => $['\n'] - func (i *$(self.config.instance)) $fn_name( + func (i $receiver) $fn_name( $['\r'] ctx $CONTEXT_CONTEXT, $(for (name, typ) in ¶ms join ($['\r']) => $name $typ,) @@ -79,15 +204,599 @@ impl<'a> ExportGenerator<'a> { } } } + + /// Generate exports for a resource interface + fn generate_interface_exports( + &self, + interface_id: wit_bindgen_core::wit_parser::InterfaceId, + tokens: &mut Tokens, + ) { + let interface = &self.config.resolve.interfaces[interface_id]; + let interface_name = interface + .name + .as_ref() + .and_then(|n| n.split('/').last()) + .unwrap_or("unknown"); + + // Check if this interface is also imported (host-provided resources) + // If so, we don't generate impl structs, constructors, or methods for exports + // because the host already manages these resources + let is_imported = self.is_interface_imported(interface_id); + + // Only generate impl structs for resources in interfaces that are NOT imported + if !is_imported { + // Find resources in this interface + for &type_id in interface.types.values() { + let type_def = &self.config.resolve.types[type_id]; + + if matches!(type_def.kind, TypeDefKind::Resource) { + if let Some(resource_name) = &type_def.name { + // Generate the resource implementation struct + self.generate_resource_impl_struct(interface_name, resource_name, tokens); + } + } + } + } + + // Generate methods on the instance for each function + for func in interface.functions.values() { + match &func.kind { + FunctionKind::Constructor(_) => { + // Skip constructors for imported interfaces (host provides them) + if !is_imported { + // Get the resource name + let resource_def = + &self.config.resolve.types[func.kind.resource().unwrap()]; + if let Some(resource_name) = &resource_def.name { + self.generate_constructor_method( + func, + interface_name, + resource_name, + tokens, + ); + } + } + } + FunctionKind::Method(_) => { + // Skip methods for imported interfaces (host provides them) + if !is_imported { + let resource_def = + &self.config.resolve.types[func.kind.resource().unwrap()]; + if let Some(resource_name) = &resource_def.name { + self.generate_resource_method( + func, + interface_name, + resource_name, + tokens, + ); + } + } + } + FunctionKind::Static(_) => { + // TODO: Handle static resource methods + } + FunctionKind::Freestanding => { + // Regular function export (interface-level) + self.generate_function(func, interface_name, true, tokens); + } + FunctionKind::AsyncFreestanding + | FunctionKind::AsyncMethod(_) + | FunctionKind::AsyncStatic(_) => { + // TODO: Handle async functions + } + } + } + } + + /// Generate the implementation struct for a resource + fn generate_resource_impl_struct( + &self, + interface_name: &str, + resource_name: &str, + tokens: &mut Tokens, + ) { + let prefixed_name = format!("{}-{}", interface_name, resource_name); + let impl_name = GoIdentifier::private(format!("{}-impl", prefixed_name)); + let handle_type = GoIdentifier::private(format!("{}-handle", prefixed_name)); + + // Exported resources don't need generics - just store the module to call wasm + quote_in! { *tokens => + $['\n'] + type $impl_name struct { + handle $handle_type + module $WAZERO_API_MODULE + } + $['\n'] + } + } + + /// Generate a constructor method on the instance + fn generate_constructor_method( + &self, + func: &Function, + interface_name: &str, + resource_name: &str, + tokens: &mut Tokens, + ) { + let prefixed_name = format!("{}-{}", interface_name, resource_name); + // Create method name with interface prefix: NewTypesAFoo instead of NewFoo + let method_name = GoIdentifier::public(format!("new-{}", prefixed_name)); + let impl_name = GoIdentifier::private(format!("{}-impl", prefixed_name)); + let handle_type = GoIdentifier::private(format!("{}-handle", prefixed_name)); + + // Build parameter list + let params = func + .params + .iter() + .map(|(name, typ)| { + let param_name = GoIdentifier::local(name); + let param_type = crate::resolve_type(typ, self.config.resolve); + (param_name, param_type) + }) + .collect::>(); + + // For constructors, we generate the call manually to properly wrap the result + // Build the full export name + let full_interface_name = self + .config + .world + .exports + .iter() + .find_map(|(key, _)| { + if let wit_bindgen_core::wit_parser::WorldKey::Interface(id) = key { + let iface = &self.config.resolve.interfaces[*id]; + if let Some(name) = &iface.name { + if name == interface_name { + if let Some(package_id) = iface.package { + let package = &self.config.resolve.packages[package_id]; + return Some(format!( + "{}:{}/{}", + package.name.namespace, package.name.name, name + )); + } + } + } + } + None + }) + .unwrap_or_else(|| interface_name.to_string()); + let export_name = format!("{}#[constructor]{}", full_interface_name, resource_name); + + // Check if any parameters are strings and build string lowering code + let has_string_params = params + .iter() + .any(|(_, typ)| matches!(typ, crate::go::GoType::String)); + + // Build call arguments - strings become ptr, len pairs + let call_args = params + .iter() + .flat_map(|(name, typ)| match typ { + crate::go::GoType::String => { + vec![quote!(uint64(ptr_$name)), quote!(uint64(len_$name))] + } + _ => vec![quote!(uint64($name))], + }) + .collect::>(); + + // Collect resource type parameters for instance receiver + let mut type_params = Vec::new(); + for interface in &self.config.analyzed_imports.interfaces { + let iface_name = interface.name.split('/').last().unwrap_or(&interface.name); + for method in &interface.methods { + if method.name.contains("[constructor]") { + if let Some(ret) = &method.return_type { + if let crate::go::GoType::OwnHandle(name) + | crate::go::GoType::BorrowHandle(name) + | crate::go::GoType::Resource(name) = &ret.go_type + { + let res_prefixed_name = format!("{}-{}", iface_name, name); + let pointer_interface_name = + GoIdentifier::public(format!("p-{}", res_prefixed_name)); + let value_type_param = + GoIdentifier::public(format!("t-{}-value", res_prefixed_name)); + let pointer_type_param = + GoIdentifier::public(format!("p-t-{}", res_prefixed_name)); + type_params.push(( + value_type_param, + pointer_type_param, + pointer_interface_name, + )); + } + } + } + } + } + + // Build receiver with or without type parameters + let receiver = if !type_params.is_empty() { + quote!(*$(self.config.instance)[$(for (value_param, pointer_param, _) in &type_params join (, ) => $value_param, $pointer_param)]) + } else { + quote!(*$(self.config.instance)) + }; + + quote_in! { *tokens => + $['\n'] + func (i $receiver) $method_name( + $['\r'] + ctx $CONTEXT_CONTEXT, + $(for (name, typ) in ¶ms join (,$['\r']) => $name $typ), + ) *$(&impl_name) { + $(if has_string_params { + memory := i.module.Memory() + realloc := i.module.ExportedFunction("cabi_realloc") + }) + $(for (name, typ) in ¶ms => + $(match typ { + crate::go::GoType::String => { + ptr_$name, len_$name, err_$name := writeString(ctx, $name, memory, realloc) + if err_$name != nil { + panic(err_$name) + } + } + _ => {} + }) + ) + raw, err := i.module.ExportedFunction($(quoted(&export_name))).Call(ctx$(for arg in &call_args => , $arg)) + if err != nil { + panic(err) + } + handle := $(&handle_type)(raw[0]) + return &$(&impl_name){ + handle: handle, + module: i.module, + } + } + $['\n'] + } + } + + /// Generate a method on the resource implementation struct + fn generate_resource_method( + &self, + func: &Function, + interface_name: &str, + resource_name: &str, + tokens: &mut Tokens, + ) { + let method_name = GoIdentifier::from_resource_function(&func.name); + let prefixed_name = format!("{}-{}", interface_name, resource_name); + let impl_name = GoIdentifier::private(format!("{}-impl", prefixed_name)); + + // Build parameter list (skip first param which is 'self') + let params = func + .params + .iter() + .skip(1) + .map(|(name, typ)| { + let param_name = GoIdentifier::local(name); + let param_type = crate::resolve_type(typ, self.config.resolve); + (param_name, param_type) + }) + .collect::>(); + + // Build return type + let result = if let Some(wit_type) = &func.result { + GoResult::Anon(crate::resolve_type(wit_type, self.config.resolve)) + } else { + GoResult::Empty + }; + + // For methods, we generate the call manually + // Build the full export name + let full_interface_name = self + .config + .world + .exports + .iter() + .find_map(|(key, _)| { + if let wit_bindgen_core::wit_parser::WorldKey::Interface(id) = key { + let iface = &self.config.resolve.interfaces[*id]; + if let Some(name) = &iface.name { + if name == interface_name { + if let Some(package_id) = iface.package { + let package = &self.config.resolve.packages[package_id]; + return Some(format!( + "{}:{}/{}", + package.name.namespace, package.name.name, name + )); + } + } + } + } + None + }) + .unwrap_or_else(|| interface_name.to_string()); + let export_name = format!("{}#{}", full_interface_name, &func.name); + + // Check if any parameters are strings + let has_string_params = params + .iter() + .any(|(_, typ)| matches!(typ, crate::go::GoType::String)); + + // Build call arguments - strings become ptr, len pairs + let call_args = params + .iter() + .flat_map(|(name, typ)| match typ { + crate::go::GoType::String => { + vec![quote!(uint64(ptr_$name)), quote!(uint64(len_$name))] + } + _ => vec![quote!(uint64($name))], + }) + .collect::>(); + + // TODO(#58): Support wasm64 architecture size + // Calculate pointer size for wasm32 (4 bytes for u32 pointers) + let ptr_size = self + .config + .sizes + .size(&wit_bindgen_core::wit_parser::Type::U32) + .size_wasm32(); + + quote_in! { *tokens => + $['\n'] + func (r *$(&impl_name)) $method_name($(for (name, typ) in ¶ms join (, ) => $name $typ)) $(&result) { + $(if has_string_params { + memory := r.module.Memory() + realloc := r.module.ExportedFunction("cabi_realloc") + }) + $(for (name, typ) in ¶ms => + $(match typ { + crate::go::GoType::String => { + ptr_$name, len_$name, err_$name := writeString(context.Background(), $name, memory, realloc) + if err_$name != nil { + panic(err_$name) + } + } + _ => {} + }) + ) + $(match &result { + GoResult::Empty => { + _, err := r.module.ExportedFunction($(quoted(&export_name))).Call(context.Background(), uint64(r.handle)$(for arg in &call_args => , $arg)) + if err != nil { + panic(err) + } + } + GoResult::Anon(GoType::Uint32) => { + raw, err := r.module.ExportedFunction($(quoted(&export_name))).Call(context.Background(), uint64(r.handle)$(for arg in &call_args => , $arg)) + if err != nil { + panic(err) + } + result := $WAZERO_API_DECODE_U32(uint64(raw[0])) + return result + } + GoResult::Anon(GoType::String) => { + // Guest exports return a pointer to where (ptr, len) is stored + raw, err := r.module.ExportedFunction($(quoted(&export_name))).Call(context.Background(), uint64(r.handle)$(for arg in &call_args => , $arg)) + if err != nil { + panic(err) + } + + // The returned i32 is a pointer to a (ptr, len) pair in memory + result_ptr := uint32(raw[0]) + memory := r.module.Memory() + + // Read ptr and len from the returned pointer location + ptr, ok1 := memory.ReadUint32Le(result_ptr + 0) + if !ok1 { + panic("failed to read string ptr") + } + len, ok2 := memory.ReadUint32Le(result_ptr + $ptr_size) + if !ok2 { + panic("failed to read string len") + } + + // Read the actual string data + buf, ok3 := memory.Read(ptr, len) + if !ok3 { + panic("failed to read string data") + } + + return string(buf) + } + GoResult::Anon(_) => { + // Other types - stub for now + panic("unsupported return type for method") + } + }) + } + $['\n'] + } + } + + /// Detect if a function has resource parameters or returns + /// Returns (resource_name, is_param) if found + fn detect_resource_in_function(&self, func: &Function) -> Option<(String, bool)> { + // Check parameters for resources + for (_name, typ) in &func.params { + if let wit_bindgen_core::wit_parser::Type::Id(id) = typ { + let type_def = &self.config.resolve.types[*id]; + match &type_def.kind { + wit_bindgen_core::wit_parser::TypeDefKind::Resource => { + if let Some(resource_name) = &type_def.name { + return Some((resource_name.clone(), true)); + } + } + wit_bindgen_core::wit_parser::TypeDefKind::Handle(handle) => { + // Extract the resource from inside the handle + let resource_id = match handle { + wit_bindgen_core::wit_parser::Handle::Own(id) + | wit_bindgen_core::wit_parser::Handle::Borrow(id) => id, + }; + let resource_def = &self.config.resolve.types[*resource_id]; + if let Some(resource_name) = &resource_def.name { + return Some((resource_name.clone(), true)); + } + } + _ => {} + } + } + } + + // Check result for resources + if let Some(result_type) = &func.result { + if let wit_bindgen_core::wit_parser::Type::Id(id) = result_type { + let type_def = &self.config.resolve.types[*id]; + match &type_def.kind { + wit_bindgen_core::wit_parser::TypeDefKind::Resource => { + if let Some(resource_name) = &type_def.name { + return Some((resource_name.clone(), false)); + } + } + wit_bindgen_core::wit_parser::TypeDefKind::Handle(handle) => { + // Extract the resource from inside the handle + let resource_id = match handle { + wit_bindgen_core::wit_parser::Handle::Own(id) + | wit_bindgen_core::wit_parser::Handle::Borrow(id) => id, + }; + let resource_def = &self.config.resolve.types[*resource_id]; + if let Some(resource_name) = &resource_def.name { + return Some((resource_name.clone(), false)); + } + } + _ => {} + } + } + } + + None + } + + /// Check if an interface is imported (i.e., host-provided) + fn is_interface_imported( + &self, + interface_id: wit_bindgen_core::wit_parser::InterfaceId, + ) -> bool { + use wit_bindgen_core::wit_parser::WorldKey; + + for (key, _) in &self.config.world.imports { + if let WorldKey::Interface(id) = key { + if *id == interface_id { + return true; + } + } + } + false + } + + /// Extract interface name from a function's resource parameters + fn get_resource_interface_name(&self, func: &wit_bindgen_core::wit_parser::Function) -> String { + use wit_bindgen_core::wit_parser::{Handle, Type, TypeDefKind, TypeOwner}; + + // Look through parameters to find a resource handle + for (_, param_type) in &func.params { + if let Type::Id(type_id) = param_type { + let type_def = self + .config + .resolve + .types + .get(*type_id) + .expect("type not found"); + + // Check if it's a handle to a resource + if let TypeDefKind::Handle(handle) = &type_def.kind { + let resource_id = match handle { + Handle::Own(id) | Handle::Borrow(id) => id, + }; + + let resource_def = self + .config + .resolve + .types + .get(*resource_id) + .expect("resource not found"); + + // If this is a type alias (from `use iface.{resource}`), follow the reference + let actual_resource_id = + if let TypeDefKind::Type(Type::Id(id)) = &resource_def.kind { + id + } else { + resource_id + }; + + let actual_resource_def = self + .config + .resolve + .types + .get(*actual_resource_id) + .expect("actual resource not found"); + + // Get the interface that owns this resource + match &actual_resource_def.owner { + TypeOwner::Interface(iface_id) => { + let interface = self + .config + .resolve + .interfaces + .get(*iface_id) + .expect("interface not found"); + if let Some(iface_name) = &interface.name { + let iface_short_name = + iface_name.split('/').last().unwrap_or(iface_name); + return iface_short_name.to_string(); + } + } + TypeOwner::World(_) => { + // This case should have been handled by following the type alias above + } + TypeOwner::None => {} + } + } + } + } + + // No resource found, return empty string + String::new() + } + + fn resolve_type_with_interface(&self, typ: &GoType, interface_name: &str) -> GoType { + match typ { + GoType::OwnHandle(name) | GoType::BorrowHandle(name) => { + let prefixed_name = if interface_name.is_empty() { + format!("{}-handle", name) + } else { + format!("{}-{}-handle", interface_name, name) + }; + GoType::Resource(prefixed_name) + } + GoType::Resource(name) => { + // Check if it's already a handle type (has -handle suffix) + if name.ends_with("-handle") { + typ.clone() + } else if name.contains('-') { + // Already prefixed but needs handle suffix + GoType::Resource(format!("{}-handle", name)) + } else { + let prefixed_name = if interface_name.is_empty() { + format!("{}-handle", name) + } else { + format!("{}-{}-handle", interface_name, name) + }; + GoType::Resource(prefixed_name) + } + } + _ => typ.clone(), + } + } } impl FormatInto for ExportGenerator<'_> { fn format_into(self, tokens: &mut Tokens) { for item in self.config.world.exports.values() { match item { - WorldItem::Function(func) => self.generate_function(func, tokens), - WorldItem::Interface { .. } => todo!("generate interface exports"), - WorldItem::Type(_) => todo!("generate type exports"), + WorldItem::Function(func) => { + // For freestanding functions with resource params, look up the interface for type resolution + // but don't use it as part of the export name (is_interface_export = false) + let interface_name = self.get_resource_interface_name(func); + self.generate_function(func, &interface_name, false, tokens); + } + WorldItem::Interface { id, .. } => { + self.generate_interface_exports(*id, tokens); + } + WorldItem::Type(_) => { + // Type exports are skipped for now + // TODO: Implement type exports + } } } } @@ -100,6 +809,7 @@ mod tests { Function, FunctionKind, Resolve, SizeAlign, Type, World, WorldItem, WorldKey, }; + use crate::codegen::ir::AnalyzedImports; use crate::go::GoIdentifier; use super::{ExportConfig, ExportGenerator}; @@ -135,18 +845,29 @@ mod tests { sizes.fill(&resolve); let instance = GoIdentifier::public("TestInstance"); + let analyzed_imports = AnalyzedImports { + interfaces: vec![], + standalone_types: vec![], + standalone_functions: vec![], + exported_resources: vec![], + factory_name: GoIdentifier::public("TestFactory"), + instance_name: GoIdentifier::public("TestInstance"), + constructor_name: GoIdentifier::public("NewTestFactory"), + }; + let config = ExportConfig { instance: &instance, world: &world, resolve: &resolve, sizes: &sizes, + analyzed_imports: &analyzed_imports, }; let generator = ExportGenerator::new(config); let mut tokens = Tokens::new(); - // Call the actual generate_function method - generator.generate_function(&func, &mut tokens); + // Call the actual generate_function method (world-level, not interface-level) + generator.generate_function(&func, "test-interface", false, &mut tokens); let generated = tokens.to_string().unwrap(); println!("Generated: {}", generated); @@ -166,7 +887,7 @@ mod tests { assert!(generated.contains("if err1 != nil {")); assert!(generated.contains("panic(err1)")); assert!(generated.contains("results1 := raw1[0]")); - assert!(generated.contains("result2 := api.DecodeU32(results1)")); + assert!(generated.contains("result2 := api.DecodeU32(uint64(results1))")); assert!(generated.contains("return result2")); } } diff --git a/cmd/gravity/src/codegen/factory.rs b/cmd/gravity/src/codegen/factory.rs index 50a1ed5..1a2c83d 100644 --- a/cmd/gravity/src/codegen/factory.rs +++ b/cmd/gravity/src/codegen/factory.rs @@ -7,7 +7,7 @@ use crate::{ go::{ GoIdentifier, comment, imports::{ - CONTEXT_CONTEXT, ERRORS_NEW, WAZERO_API_MEMORY, WAZERO_API_MODULE, + CONTEXT_CONTEXT, ERRORS_NEW, SYNC_MUTEX, WAZERO_API_MEMORY, WAZERO_API_MODULE, WAZERO_COMPILED_MODULE, WAZERO_NEW_MODULE_CONFIG, WAZERO_NEW_RUNTIME, WAZERO_RUNTIME, }, }, @@ -78,83 +78,448 @@ impl<'a> FactoryGenerator<'a> { .. } = &self.config.analyzed_imports; let wasm_var_name = self.config.wasm_var_name; - // Build the parameter list - let params = self.build_parameters(); - quote_in! { *tokens => - $['\n'] - type $factory_name struct { - runtime $WAZERO_RUNTIME - module $WAZERO_COMPILED_MODULE + + // Get type parameters as structured data + let type_param_data = self.build_parameters_with_generics(); + + // Collect resource info before quote_in! to avoid borrow issues + let resource_info = self.collect_resource_info(); + + // Build interface type parameters + let mut interface_type_params: std::collections::HashMap< + String, + Vec<(GoIdentifier, GoIdentifier)>, + > = std::collections::HashMap::new(); + for (iface_name, _resource_name, prefixed_name, type_param_name) in &resource_info { + let value_type_param = GoIdentifier::public(format!("t-{}-value", prefixed_name)); + let pointer_type_param = + GoIdentifier::public(format!("p-{}", String::from(type_param_name))); + interface_type_params + .entry(iface_name.clone()) + .or_insert_with(Vec::new) + .push((value_type_param, pointer_type_param)); + } + + // Generate resource table type if we have resources + let has_resources = !type_param_data.is_empty(); + + if has_resources { + self.generate_resource_table_type(tokens); + } + + if has_resources { + let interfaces = &self.config.analyzed_imports.interfaces; + + // Pre-build interface parameter tokens completely before quote_in + let mut interface_params = Vec::new(); + for interface in interfaces.iter() { + let interface_name = interface.name.split('/').last().unwrap_or(&interface.name); + if let Some(iface_type_params) = interface_type_params.get(interface_name) { + let type_args: Vec<_> = iface_type_params + .iter() + .flat_map(|(v, p)| vec![v, p]) + .collect(); + let param_tokens = quote!($(&interface.constructor_param_name) $(&interface.go_interface_name)[$(for tp in type_args join (, ) => $tp)],); + interface_params.push(param_tokens); + } else { + let param_tokens = quote!($(&interface.constructor_param_name) $(&interface.go_interface_name),); + interface_params.push(param_tokens); + } } - $['\n'] - func $constructor_name( - $['\r'] - $params - $['\r'] - ) (*$factory_name, error) { - wazeroRuntime := $WAZERO_NEW_RUNTIME(ctx) - - $(for chain in self.config.import_chains.values() => + + quote_in! { *tokens => + $['\n'] + type $factory_name[$(for (value_param, pointer_param, pointer_iface) in &type_param_data join (, ) => $value_param any, $pointer_param $pointer_iface[$value_param])] struct { + runtime $WAZERO_RUNTIME + module $WAZERO_COMPILED_MODULE + $['\r'] + $(for (_iface_name, _resource_name, prefixed_name, type_param_name) in resource_info.iter() join ($['\r']) => + $(&GoIdentifier::public(format!("{}-resource-table", prefixed_name))) *$(&GoIdentifier::private(format!("{}-resource-table", prefixed_name)))[$(&GoIdentifier::public(format!("t-{}-value", prefixed_name))), $(&GoIdentifier::public(format!("p-{}", String::from(type_param_name))))]) + } + $['\n'] + func $constructor_name[$(for (value_param, pointer_param, pointer_iface) in &type_param_data join (, ) => $value_param any, $pointer_param $pointer_iface[$value_param])]( + $['\r'] + ctx $CONTEXT_CONTEXT, + $(for param_tokens in &interface_params join ($['\r']) => $param_tokens) + $['\r'] + ) (*$factory_name[$(for (value_param, pointer_param, _) in &type_param_data join (, ) => $value_param, $pointer_param)], error) { + $['\r'] + wazeroRuntime := $WAZERO_NEW_RUNTIME(ctx) + $['\r'] + $(comment(&["Initialize resource tables before host module instantiation"])) + $(for (_iface_name, _resource_name, prefixed_name, type_param_name) in resource_info.iter() join ($['\r']) => + $(&GoIdentifier::private(format!("{}_resource_table", prefixed_name))) := new$(&GoIdentifier::public(format!("{}-resource-table", prefixed_name)))[$(&GoIdentifier::public(format!("t-{}-value", prefixed_name))), $(&GoIdentifier::public(format!("p-{}", String::from(type_param_name))))]()) + $['\r'] + $['\r'] + $(comment(&["Instantiate import host modules"])) + $(for chain in self.config.import_chains.values() => $chain $['\r'] ) - - $(comment(&[ + $['\r'] + $(comment(&["Instantiate export resource management host modules"])) + $(for chain in self.generate_export_resource_chains().values() => + $chain + $['\r'] + ) + $['\r'] + $(comment(&[ "Compiling the module takes a LONG time, so we want to do it once and hold", "onto it with the Runtime", ])) - module, err := wazeroRuntime.CompileModule(ctx, $wasm_var_name) - if err != nil { - return nil, err + module, err := wazeroRuntime.CompileModule(ctx, $wasm_var_name) + if err != nil { + return nil, err + } + return &$factory_name[$(for (value_param, pointer_param, _) in &type_param_data join (, ) => $value_param, $pointer_param)]{ + runtime: wazeroRuntime, + module: module, + $['\r'] + $(for (_iface_name, _resource_name, prefixed_name, _) in resource_info.iter() => + $(&GoIdentifier::public(format!("{}-resource-table", prefixed_name))): $(&GoIdentifier::private(format!("{}_resource_table", prefixed_name))),$['\r']) + }, nil } - return &$factory_name{ - runtime: wazeroRuntime, - module: module, - }, nil - } $['\n'] - func (f *$factory_name) Instantiate(ctx $CONTEXT_CONTEXT) (*$instance_name, error) { + func (f *$factory_name[$(for (value_param, pointer_param, _) in &type_param_data join (, ) => $value_param, $pointer_param)]) Instantiate(ctx $CONTEXT_CONTEXT) (*$instance_name[$(for (value_param, pointer_param, _) in &type_param_data join (, ) => $value_param, $pointer_param)], error) { if module, err := f.runtime.InstantiateModule(ctx, f.module, $WAZERO_NEW_MODULE_CONFIG()); err != nil { return nil, err } else { - return &$instance_name{module}, nil + return &$instance_name[$(for (value_param, pointer_param, _) in &type_param_data join (, ) => $value_param, $pointer_param)]{ + module: module, + $['\r'] + $(for (_iface_name, _resource_name, prefixed_name, _) in resource_info.iter() => + $(&GoIdentifier::public(format!("{}-resource-table", prefixed_name))): f.$(&GoIdentifier::public(format!("{}-resource-table", prefixed_name))),$['\r']) + }, nil } } $['\n'] - func (f *$factory_name) Close(ctx $CONTEXT_CONTEXT) { + func (f *$factory_name[$(for (value_param, pointer_param, _) in &type_param_data join (, ) => $value_param, $pointer_param)]) Close(ctx $CONTEXT_CONTEXT) { f.runtime.Close(ctx) } $['\n'] - }; + }; + } else { + let interfaces = &self.config.analyzed_imports.interfaces; + quote_in! { *tokens => + $['\n'] + type $factory_name struct { + runtime $WAZERO_RUNTIME + module $WAZERO_COMPILED_MODULE + } + $['\n'] + func $constructor_name( + ctx $CONTEXT_CONTEXT, + $(for interface in interfaces.iter() join ($['\r']) => + $(&interface.constructor_param_name) $(&interface.go_interface_name),) + ) (*$factory_name, error) { + wazeroRuntime := $WAZERO_NEW_RUNTIME(ctx) + + $(comment(&["Instantiate import host modules"])) + $(for chain in self.config.import_chains.values() => + $chain + $['\r'] + ) + + $(comment(&[ + "Compiling the module takes a LONG time, so we want to do it once and hold", + "onto it with the Runtime", + ])) + module, err := wazeroRuntime.CompileModule(ctx, $wasm_var_name) + if err != nil { + return nil, err + } + return &$factory_name{ + runtime: wazeroRuntime, + module: module, + }, nil + } + $['\n'] + func (f *$factory_name) Instantiate(ctx $CONTEXT_CONTEXT) (*$instance_name, error) { + if module, err := f.runtime.InstantiateModule(ctx, f.module, $WAZERO_NEW_MODULE_CONFIG()); err != nil { + return nil, err + } else { + return &$instance_name{module}, nil + } + } + $['\n'] + func (f *$factory_name) Close(ctx $CONTEXT_CONTEXT) { + f.runtime.Close(ctx) + } + $['\n'] + }; + } } /// Generate the Instance struct, and methods. fn generate_instance(&self, tokens: &mut Tokens) { let instance_name = &self.config.analyzed_imports.instance_name; - quote_in! { *tokens => - type $instance_name struct { - module $WAZERO_API_MODULE - } - $['\n'] - func (i *$instance_name) Close(ctx $CONTEXT_CONTEXT) error { - if err := i.module.Close(ctx); err != nil { - return err + + // Get type parameter data for resources + let type_param_data = self.build_parameters_with_generics(); + let resource_info = self.collect_resource_info(); + let has_resources = !type_param_data.is_empty(); + + if has_resources { + // Generic instance with resource tables + quote_in! { *tokens => + type $instance_name[$(for (value_param, pointer_param, pointer_iface) in &type_param_data join (, ) => $value_param any, $pointer_param $pointer_iface[$value_param])] struct { + module $WAZERO_API_MODULE + $['\r'] + $(for (_iface_name, _resource_name, prefixed_name, type_param_name) in resource_info.iter() join ($['\r']) => + $(&GoIdentifier::public(format!("{}-resource-table", prefixed_name))) *$(&GoIdentifier::private(format!("{}-resource-table", prefixed_name)))[$(&GoIdentifier::public(format!("t-{}-value", prefixed_name))), $(&GoIdentifier::public(format!("p-{}", String::from(type_param_name))))]) } + $['\n'] + func (i *$instance_name[$(for (value_param, pointer_param, _) in &type_param_data join (, ) => $value_param, $pointer_param)]) Close(ctx $CONTEXT_CONTEXT) error { + if err := i.module.Close(ctx); err != nil { + return err + } - return nil - } - $['\n'] - }; + return nil + } + $['\n'] + }; + } else { + // Non-generic instance (no resources) + quote_in! { *tokens => + type $instance_name struct { + module $WAZERO_API_MODULE + } + $['\n'] + func (i *$instance_name) Close(ctx $CONTEXT_CONTEXT) error { + if err := i.module.Close(ctx); err != nil { + return err + } + + return nil + } + $['\n'] + }; + } } - /// Build parameter list for factory constructor - fn build_parameters(&self) -> Tokens { + /// Build type parameters and parameter list for factory constructor + fn build_parameters_with_generics(&self) -> Vec<(GoIdentifier, GoIdentifier, GoIdentifier)> { + // Use collect_resource_info for consistency + let resource_info = self.collect_resource_info(); + let mut result = Vec::new(); + + for (_interface_name, _resource_name, prefixed_name, type_param_name) in &resource_info { + let pointer_interface_name = GoIdentifier::public(format!("p-{}", prefixed_name)); + let value_type_param = GoIdentifier::public(format!("t-{}-value", prefixed_name)); + let pointer_type_param = + GoIdentifier::public(format!("p-{}", String::from(type_param_name))); + + result.push((value_type_param, pointer_type_param, pointer_interface_name)); + } + + result + } + + /// Collect resource information (interface name, resource name, prefixed name, and type parameter) + fn collect_resource_info(&self) -> Vec<(String, String, String, GoIdentifier)> { + let mut resources = Vec::new(); let interfaces = &self.config.analyzed_imports.interfaces; - quote! { - ctx $CONTEXT_CONTEXT, - $(for interface in interfaces.iter() => - $(&interface.constructor_param_name) $(&interface.go_interface_name),) + for interface in interfaces.iter() { + let interface_name = interface.name.split('/').last().unwrap_or(&interface.name); + for method in &interface.methods { + if method.name.contains("[constructor]") { + if let Some(ret) = &method.return_type { + if let crate::go::GoType::OwnHandle(name) + | crate::go::GoType::BorrowHandle(name) + | crate::go::GoType::Resource(name) = &ret.go_type + { + let prefixed_name = format!("{}-{}", interface_name, name); + let type_param_name = + GoIdentifier::public(format!("t-{}", prefixed_name)); + resources.push(( + interface_name.to_string(), + name.clone(), + prefixed_name, + type_param_name, + )); + } + } + } + } + } + + resources + } + + /// Generate export resource management host functions + /// These are [resource-new] and [resource-drop] functions that WASM calls + /// when it creates/destroys exported resources + fn generate_export_resource_chains(&self) -> BTreeMap> { + let mut chains = BTreeMap::new(); + + // Group exported resources by their wazero module name + let mut resources_by_module: BTreeMap< + String, + Vec<&crate::codegen::ir::ExportedResourceInfo>, + > = BTreeMap::new(); + + for resource in &self.config.analyzed_imports.exported_resources { + resources_by_module + .entry(resource.wazero_export_module_name.clone()) + .or_insert_with(Vec::new) + .push(resource); + } + + // Generate a host module for each export interface + for (i, (module_name, resources)) in resources_by_module.into_iter().enumerate() { + let err = &GoIdentifier::private(format!("err_export{i}")); + let mut chain = quote! { + _, $err := wazeroRuntime.NewHostModuleBuilder($(quoted(&module_name))). + }; + + for resource in resources { + let resource_name = &resource.resource_name; + + // Generate [resource-new]foo function + // For MVP: simple identity mapping (rep -> rep) + chain.push(); + quote_in! { chain => + NewFunctionBuilder(). + WithFunc(func( + ctx $CONTEXT_CONTEXT, + mod $WAZERO_API_MODULE, + rep uint32, + ) uint32 { + $(comment(&[ + &format!("[resource-new]{}: allocate handle for WASM-created resource", resource_name), + "For MVP: using identity mapping (handle = rep)", + ])) + return rep + }). + Export($(quoted(&format!("[resource-new]{}", resource_name)))). + }; + + // Generate [resource-drop]foo function + chain.push(); + quote_in! { chain => + NewFunctionBuilder(). + WithFunc(func( + ctx $CONTEXT_CONTEXT, + mod $WAZERO_API_MODULE, + handle uint32, + ) { + $(comment(&[ + &format!("[resource-drop]{}: cleanup for WASM resource", resource_name), + "For MVP: no-op (WASM manages its own resources)", + ])) + _ = handle + }). + Export($(quoted(&format!("[resource-drop]{}", resource_name)))). + }; + } + + chain.push(); + quote_in! { chain => + Instantiate(ctx) + if $err != nil { + return nil, $err + } + }; + + chains.insert(module_name, chain); + } + + chains + } + + /// Generate the resource table type + fn generate_resource_table_type(&self, tokens: &mut Tokens) { + // Generate a separate table type for each resource + for (interface_name, resource_name, prefixed_name, type_param_name) in + self.collect_resource_info() + { + let table_type_name = + GoIdentifier::private(format!("{}-resource-table", prefixed_name)); + let resource_interface = GoIdentifier::public(&prefixed_name); + let handle_type = GoIdentifier::private(format!("{}-handle", prefixed_name)); + let value_type_param = GoIdentifier::public(format!("t-{}-value", prefixed_name)); + let pointer_type_param = + GoIdentifier::public(format!("p-{}", String::from(&type_param_name))); + let pointer_interface_name = GoIdentifier::public(format!("p-{}", prefixed_name)); + + let pointer_comment = format!( + "{} constrains a pointer to a type implementing the {} interface.", + String::from(&pointer_interface_name), + String::from(&resource_interface) + ); + + let table_comment = format!( + "{} is a resource table for {} resources from the {} interface.", + String::from(&table_type_name), + resource_name, + interface_name + ); + + quote_in! { *tokens => + $['\n'] + $(comment(&[pointer_comment.as_str()])) + type $(&pointer_interface_name)[$(&value_type_param) any] interface { + *$(&value_type_param) + $(&resource_interface) + } + $['\n'] + $(comment(&[table_comment.as_str()])) + type $(&table_type_name)[$(&value_type_param) any, $(&pointer_type_param) $(&pointer_interface_name)[$(&value_type_param)]] struct { + mu $SYNC_MUTEX + nextHandle uint32 + table map[$(&handle_type)]*$(&value_type_param) + } + $['\n'] + func new$(&GoIdentifier::public(format!("{}-resource-table", prefixed_name)))[$(&value_type_param) any, $(&pointer_type_param) $(&pointer_interface_name)[$(&value_type_param)]]() *$(&table_type_name)[$(&value_type_param), $(&pointer_type_param)] { + return &$(&table_type_name)[$(&value_type_param), $(&pointer_type_param)]{ + nextHandle: 1, + table: make(map[$(&handle_type)]*$(&value_type_param)), + } + } + $['\n'] + $(comment(&["Store adds a resource to the table and returns its handle."])) + func (t *$(&table_type_name)[$(&value_type_param), $(&pointer_type_param)]) Store(resource $(&value_type_param)) $(&handle_type) { + t.mu.Lock() + defer t.mu.Unlock() + handle := $(&handle_type)(t.nextHandle) + t.nextHandle++ + t.table[handle] = &resource + return handle + } + $['\n'] + $(comment(&["get returns a pointer to the resource from the table by its handle."])) + func (t *$(&table_type_name)[$(&value_type_param), $(&pointer_type_param)]) get(handle $(&handle_type)) ($(&pointer_type_param), bool) { + t.mu.Lock() + defer t.mu.Unlock() + resource, ok := t.table[handle] + if !ok { + var zero $(&pointer_type_param) + return zero, false + } + return resource, true + } + $['\n'] + $(comment(&["Get retrieves a resource from the table by its handle."])) + func (t *$(&table_type_name)[$(&value_type_param), $(&pointer_type_param)]) Get(handle $(&handle_type)) ($(&value_type_param), bool) { + t.mu.Lock() + defer t.mu.Unlock() + resource, ok := t.table[handle] + if !ok { + var zero $(&value_type_param) + return zero, false + } + return *resource, true + } + $['\n'] + $(comment(&["Remove deletes a resource from the table."])) + func (t *$(&table_type_name)[$(&value_type_param), $(&pointer_type_param)]) Remove(handle $(&handle_type)) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.table, handle) + } + $['\n'] + } } } } @@ -185,6 +550,7 @@ mod tests { interfaces: vec![], standalone_types: vec![], standalone_functions: vec![], + exported_resources: vec![], factory_name: GoIdentifier::public("test-factory"), instance_name: GoIdentifier::public("test-instance"), constructor_name: GoIdentifier::public("test-constructor"), diff --git a/cmd/gravity/src/codegen/func.rs b/cmd/gravity/src/codegen/func.rs index 169cb03..dffedc9 100644 --- a/cmd/gravity/src/codegen/func.rs +++ b/cmd/gravity/src/codegen/func.rs @@ -1,17 +1,20 @@ use std::mem; -use genco::prelude::*; +use genco::{prelude::*, tokens::static_literal}; use wit_bindgen_core::{ abi::{Bindgen, Instruction}, - wit_parser::{Alignment, ArchitectureSize, Resolve, Result_, SizeAlign, Type}, + wit_parser::{ + Alignment, ArchitectureSize, Handle, Resolve, Result_, SizeAlign, Type, TypeDefKind, + }, }; use crate::{ go::{ GoIdentifier, GoResult, GoType, Operand, comment, imports::{ - ERRORS_NEW, WAZERO_API_DECODE_I32, WAZERO_API_DECODE_U32, WAZERO_API_ENCODE_I32, - WAZERO_API_ENCODE_U32, + ERRORS_NEW, REFLECT_VALUE_OF, WAZERO_API_DECODE_F32, WAZERO_API_DECODE_F64, + WAZERO_API_DECODE_I32, WAZERO_API_DECODE_U32, WAZERO_API_ENCODE_F32, + WAZERO_API_ENCODE_F64, WAZERO_API_ENCODE_I32, WAZERO_API_ENCODE_U32, }, }, resolve_type, resolve_wasm_type, @@ -21,16 +24,33 @@ use crate::{ /// /// Functions in the Component Model can be imported into a world or /// exported from a world. + +/// Context for resource handling in imported functions +#[derive(Clone)] +struct ResourceContext { + /// Interface name (e.g., "types-a") + interface_name: String, + /// Resource name (e.g., "foo") + resource_name: String, + /// Resource table variable name (e.g., "typesAFooResourceTable") + table_var: String, +} + enum Direction<'a> { /// The function is imported into the world. Import { /// The name of the parameter representing the interface instance /// in the generated host binding function. param_name: &'a GoIdentifier, + /// Optional resource context for resource constructors and methods + resource_context: Option, }, /// The function is exported from the world. #[allow(dead_code, reason = "halfway through refactor of func bindings")] - Export, + Export { + /// Optional resource context for resource parameters/returns + resource_context: Option, + }, } pub struct Func<'a> { @@ -42,6 +62,8 @@ pub struct Func<'a> { block_storage: Vec>, blocks: Vec<(Tokens, Vec)>, sizes: &'a SizeAlign, + /// Override the export name used in CallWasm instructions (for interface-qualified names) + export_name: Option, } impl<'a> Func<'a> { @@ -49,7 +71,68 @@ impl<'a> Func<'a> { #[allow(dead_code, reason = "halfway through refactor of func bindings")] pub fn export(result: GoResult, sizes: &'a SizeAlign) -> Self { Self { - direction: Direction::Export, + direction: Direction::Export { + resource_context: None, + }, + args: Vec::new(), + result, + tmp: 0, + body: Tokens::new(), + block_storage: Vec::new(), + blocks: Vec::new(), + sizes, + export_name: None, + } + } + + /// Create a new exported function with resource context. + pub fn export_with_resource( + result: GoResult, + sizes: &'a SizeAlign, + interface_name: String, + resource_name: String, + ) -> Self { + let interface_pascal = interface_name + .split('-') + .map(|s| { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } + }) + .collect::>() + .join(""); + let resource_pascal = resource_name + .split('-') + .map(|s| { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } + }) + .collect::>() + .join(""); + + let interface_camel = { + let mut c = interface_pascal.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_lowercase().collect::() + c.as_str(), + } + }; + + let table_var = format!("{}{}ResourceTable", interface_camel, resource_pascal); + + Self { + direction: Direction::Export { + resource_context: Some(ResourceContext { + interface_name, + resource_name, + table_var, + }), + }, args: Vec::new(), result, tmp: 0, @@ -57,13 +140,46 @@ impl<'a> Func<'a> { block_storage: Vec::new(), blocks: Vec::new(), sizes, + export_name: None, } } /// Create a new exported function. pub fn import(param_name: &'a GoIdentifier, result: GoResult, sizes: &'a SizeAlign) -> Self { Self { - direction: Direction::Import { param_name }, + direction: Direction::Import { + param_name, + resource_context: None, + }, + args: Vec::new(), + result, + tmp: 0, + body: Tokens::new(), + block_storage: Vec::new(), + blocks: Vec::new(), + sizes, + export_name: None, + } + } + + /// Create a new imported function with resource context. + pub fn import_with_resource( + param_name: &'a GoIdentifier, + result: GoResult, + sizes: &'a SizeAlign, + interface_name: String, + resource_name: String, + table_var: String, + ) -> Self { + Self { + direction: Direction::Import { + param_name, + resource_context: Some(ResourceContext { + interface_name, + resource_name, + table_var, + }), + }, args: Vec::new(), result, tmp: 0, @@ -71,6 +187,7 @@ impl<'a> Func<'a> { block_storage: Vec::new(), blocks: Vec::new(), sizes, + export_name: None, } } @@ -92,6 +209,11 @@ impl<'a> Func<'a> { &self.body } + /// Set the export name to use in CallWasm instructions (for interface-qualified names) + pub fn set_export_name(&mut self, name: String) { + self.export_name = Some(name); + } + fn push_arg(&mut self, value: &str) { self.args.push(value.into()) } @@ -114,6 +236,11 @@ impl Bindgen for Func<'_> { let iter_element = "e"; let iter_base = "base"; + let payload = &format!("{:?}", inst); + quote_in! { + self.body => + $(comment([payload])) + } match inst { Instruction::GetArg { nth } => { let arg = &format!("arg{nth}"); @@ -138,7 +265,7 @@ impl Bindgen for Func<'_> { let realloc = &format!("realloc{tmp}"); let operand = &operands[0]; match self.direction { - Direction::Export => { + Direction::Export { .. } => { quote_in! { self.body => $['\r'] $memory := i.module.Memory() @@ -186,33 +313,35 @@ impl Bindgen for Func<'_> { let ret = &format!("results{tmp}"); let err = &format!("err{tmp}"); let default = &format!("default{tmp}"); + // Use export_name if set, otherwise use the instruction's name + let export_name = self.export_name.as_deref().unwrap_or(name); // TODO(#17): Wrapping every argument in `uint64` is bad and we should instead be looking // at the types and converting with proper guards in place quote_in! { self.body => $['\r'] $(match &self.result { GoResult::Anon(GoType::ValueOrError(typ)) => { - $raw, $err := i.module.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) + $raw, $err := i.module.ExportedFunction($(quoted(export_name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) if $err != nil { var $default $(typ.as_ref()) return $default, $err } } GoResult::Anon(GoType::Error) => { - $raw, $err := i.module.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) + $raw, $err := i.module.ExportedFunction($(quoted(export_name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) if $err != nil { return $err } } GoResult::Anon(_) => { - $raw, $err := i.module.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) + $raw, $err := i.module.ExportedFunction($(quoted(export_name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) if $err != nil { panic($err) } } GoResult::Empty => { - _, $err := i.module.ExportedFunction($(quoted(*name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) + _, $err := i.module.ExportedFunction($(quoted(export_name))).Call(ctx, $(for op in operands.iter() join (, ) => uint64($op))) $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) if $err != nil { panic($err) @@ -310,10 +439,23 @@ impl Bindgen for Func<'_> { let tmp = self.tmp(); let result = &format!("result{tmp}"); let operand = &operands[0]; - quote_in! { self.body => - $['\r'] - $result := $WAZERO_API_ENCODE_U32($operand) - }; + match self.direction { + Direction::Import { .. } => { + // For host functions (imports), just pass through the value + // The value is already uint32 and doesn't need encoding + quote_in! { self.body => + $['\r'] + $result := $operand + }; + } + Direction::Export { .. } => { + // For exports, encode the value for passing to Wasm + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_ENCODE_U32($operand) + }; + } + } results.push(Operand::SingleValue(result.into())); } Instruction::U32FromI32 => { @@ -322,7 +464,7 @@ impl Bindgen for Func<'_> { let operand = &operands[0]; quote_in! { self.body => $['\r'] - $result := $WAZERO_API_DECODE_U32($operand) + $result := $WAZERO_API_DECODE_U32(uint64($operand)) }; results.push(Operand::SingleValue(result.into())); } @@ -510,7 +652,7 @@ impl Bindgen for Func<'_> { } }; - results.push(Operand::MultiValue((value.into(), err.into()))); + results.push(Operand::DoubleValue(value.into(), err.into())); } Instruction::ResultLift { result: @@ -557,7 +699,7 @@ impl Bindgen for Func<'_> { } } Instruction::CallInterface { func, .. } => { - let ident = GoIdentifier::public(&func.name); + let ident = GoIdentifier::from_resource_function(&func.name); let tmp = self.tmp(); let args = quote!($(for op in operands.iter() join (, ) => $op)); let returns = match &func.result { @@ -567,23 +709,130 @@ impl Bindgen for Func<'_> { let value = &format!("value{tmp}"); let err = &format!("err{tmp}"); let ok = &format!("ok{tmp}"); - match self.direction { + + // Check if this is a resource constructor or method + let is_constructor = func.name.starts_with("[constructor]"); + let is_method = func.name.starts_with("[method]"); + + // Check if first parameter is a resource type + let first_param_is_resource = func.params.first().map_or(false, |(_, typ)| { + if let wit_bindgen_core::wit_parser::Type::Id(id) = typ { + let type_def = &resolve.types[*id]; + matches!( + type_def.kind, + wit_bindgen_core::wit_parser::TypeDefKind::Handle(_) + | wit_bindgen_core::wit_parser::TypeDefKind::Resource + ) + } else { + false + } + }); + + match &self.direction { Direction::Export { .. } => todo!("TODO(#10): handle export direction"), - Direction::Import { param_name, .. } => { - quote_in! { self.body => - $['\r'] - $(match returns { - GoType::Nothing => $param_name.$ident(ctx, $args), - GoType::Bool | GoType::Uint32 | GoType::Interface | GoType::String | GoType::UserDefined(_) => $value := $param_name.$ident(ctx, $args), - GoType::Error => $err := $param_name.$ident(ctx, $args), - GoType::ValueOrError(_) => { - $value, $err := $param_name.$ident(ctx, $args) + Direction::Import { + param_name, + resource_context, + } => { + if is_constructor && resource_context.is_some() { + // Constructor: call interface method, store in table, return handle + quote_in! { self.body => + $['\r'] + $(match returns { + GoType::OwnHandle(_) | GoType::Resource(_) => { + $value := $(*param_name).$ident(ctx, $args) + } + _ => $(comment(&["Unexpected return type for constructor"])) + }) + } + } else if is_method && resource_context.is_some() { + // Method: lookup resource from table, call method on resource + let ctx = resource_context.as_ref().unwrap(); + let table_var = &ctx.table_var; + let resource_var = &format!("resource{tmp}"); + let ok_var = &format!("ok{tmp}"); + // First operand should be the handle (self parameter) + let handle_operand = &operands[0]; + // Remaining operands are method parameters + let method_args = if operands.len() > 1 { + quote!($(for op in operands.iter().skip(1) join (, ) => $op)) + } else { + quote!() + }; + + quote_in! { self.body => + $['\r'] + $resource_var, $ok_var := $table_var.get($handle_operand) + if !$ok_var { + panic("invalid resource handle") } - GoType::ValueOrOk(_) => { - $value, $ok := $param_name.$ident(ctx, $args) + $(match returns { + GoType::Nothing => $resource_var.$(&ident)(ctx$(if !method_args.is_empty() => , $method_args)), + GoType::Bool | GoType::Uint32 | GoType::Interface | GoType::String | GoType::UserDefined(_) => $value := $resource_var.$(&ident)(ctx$(if !method_args.is_empty() => , $method_args)), + GoType::Error => $err := $resource_var.$(&ident)(ctx$(if !method_args.is_empty() => , $method_args)), + GoType::ValueOrError(_) => { + $value, $err := $resource_var.$(&ident)(ctx$(if !method_args.is_empty() => , $method_args)) + } + GoType::ValueOrOk(_) => { + $value, $ok := $resource_var.$(&ident)(ctx$(if !method_args.is_empty() => , $method_args)) + } + _ => $(comment(&["TODO(#9): handle return type"])) + }) + } + } else if resource_context.is_some() + && !operands.is_empty() + && first_param_is_resource + { + // Freestanding function with resource parameter: lookup resource from table + let ctx = resource_context.as_ref().unwrap(); + let table_var = &ctx.table_var; + let resource_var = &format!("resource{tmp}"); + let ok_var = &format!("ok{tmp}"); + // First operand should be the handle (resource parameter) + let handle_operand = &operands[0]; + // Remaining operands are other parameters + let remaining_args = if operands.len() > 1 { + quote!($(for op in operands.iter().skip(1) join (, ) => , $op)) + } else { + quote!() + }; + + quote_in! { self.body => + $['\r'] + $resource_var, $ok_var := $table_var.get($handle_operand) + if !$ok_var { + panic("invalid resource handle") } - _ => $(comment(&["TODO(#9): handle return type"])) - }) + $(match returns { + GoType::Nothing => $(*param_name).$ident(ctx, $resource_var$remaining_args), + GoType::Bool | GoType::Uint32 | GoType::Interface | GoType::String | GoType::UserDefined(_) | GoType::OwnHandle(_) | GoType::BorrowHandle(_) | GoType::Resource(_) => $value := $(*param_name).$ident(ctx, $resource_var$remaining_args), + GoType::Error => $err := $(*param_name).$ident(ctx, $resource_var$remaining_args), + GoType::ValueOrError(_) => { + $value, $err := $(*param_name).$ident(ctx, $resource_var$remaining_args) + } + GoType::ValueOrOk(_) => { + $value, $ok := $(*param_name).$ident(ctx, $resource_var$remaining_args) + } + _ => $(comment(&["TODO(#9): handle return type"])) + }) + } + } else { + // Regular interface call (not a resource constructor or method) + quote_in! { self.body => + $['\r'] + $(match returns { + GoType::Nothing => $(*param_name).$ident(ctx, $args), + GoType::Bool | GoType::Uint32 | GoType::Interface | GoType::String | GoType::UserDefined(_) | GoType::OwnHandle(_) | GoType::BorrowHandle(_) | GoType::Resource(_) => $value := $(*param_name).$ident(ctx, $args), + GoType::Error => $err := $(*param_name).$ident(ctx, $args), + GoType::ValueOrError(_) => { + $value, $err := $(*param_name).$ident(ctx, $args) + } + GoType::ValueOrOk(_) => { + $value, $ok := $(*param_name).$ident(ctx, $args) + } + _ => $(comment(&["TODO(#9): handle return type"])) + }) + } } } } @@ -593,17 +842,20 @@ impl Bindgen for Func<'_> { | GoType::Uint32 | GoType::Interface | GoType::UserDefined(_) - | GoType::String => { + | GoType::String + | GoType::OwnHandle(_) + | GoType::BorrowHandle(_) + | GoType::Resource(_) => { results.push(Operand::SingleValue(value.into())); } GoType::Error => { results.push(Operand::SingleValue(err.into())); } GoType::ValueOrError(_) => { - results.push(Operand::MultiValue((value.into(), err.into()))); + results.push(Operand::DoubleValue(value.into(), err.into())); } GoType::ValueOrOk(_) => { - results.push(Operand::MultiValue((value.into(), ok.into()))) + results.push(Operand::DoubleValue(value.into(), ok.into())) } _ => todo!("TODO(#9): handle return type - {returns:?}"), } @@ -619,16 +871,16 @@ impl Bindgen for Func<'_> { let ptr = &operands[1]; if let Operand::Literal(byte) = tag { match &self.direction { - Direction::Export => { + Direction::Export { .. } => { quote_in! { self.body => $['\r'] - i.module.Memory().WriteByte($ptr+$offset, $byte) + i.module.Memory().WriteByte(uint32($ptr+$offset), $byte) } } Direction::Import { .. } => { quote_in! { self.body => $['\r'] - mod.Memory().WriteByte($ptr+$offset, $byte) + mod.Memory().WriteByte(uint32($ptr+$offset), $byte) } } } @@ -636,7 +888,7 @@ impl Bindgen for Func<'_> { let tmp = self.tmp(); let byte = format!("byte{tmp}"); match &self.direction { - Direction::Export => { + Direction::Export { .. } => { quote_in! { self.body => $['\r'] var $(&byte) uint8 @@ -649,7 +901,7 @@ impl Bindgen for Func<'_> { $(comment(["TODO(#8): Return an error if the return type allows it"])) panic($ERRORS_NEW("invalid int8 value encountered")) } - i.module.Memory().WriteByte($ptr+$offset, $byte) + i.module.Memory().WriteByte(uint32($ptr+$offset), $byte) } } Direction::Import { .. } => { @@ -664,7 +916,7 @@ impl Bindgen for Func<'_> { default: panic($ERRORS_NEW("invalid int8 value encountered")) } - mod.Memory().WriteByte($ptr+$offset, $byte) + mod.Memory().WriteByte(uint32($ptr+$offset), $byte) } } } @@ -676,16 +928,16 @@ impl Bindgen for Func<'_> { let tag = &operands[0]; let ptr = &operands[1]; match &self.direction { - Direction::Export => { + Direction::Export { .. } => { quote_in! { self.body => $['\r'] - i.module.Memory().WriteUint32Le($ptr+$offset, $tag) + i.module.Memory().WriteUint32Le(uint32($ptr+$offset), uint32($tag)) } } Direction::Import { .. } => { quote_in! { self.body => $['\r'] - mod.Memory().WriteUint32Le($ptr+$offset, $tag) + mod.Memory().WriteUint32Le(uint32($ptr+$offset), uint32($tag)) } } } @@ -696,16 +948,16 @@ impl Bindgen for Func<'_> { let len = &operands[0]; let ptr = &operands[1]; match &self.direction { - Direction::Export => { + Direction::Export { .. } => { quote_in! { self.body => $['\r'] - i.module.Memory().WriteUint32Le($ptr+$offset, uint32($len)) + i.module.Memory().WriteUint32Le(uint32($ptr+$offset), uint32($len)) } } Direction::Import { .. } => { quote_in! { self.body => $['\r'] - mod.Memory().WriteUint32Le($ptr+$offset, uint32($len)) + mod.Memory().WriteUint32Le(uint32($ptr+$offset), uint32($len)) } } } @@ -716,16 +968,16 @@ impl Bindgen for Func<'_> { let value = &operands[0]; let ptr = &operands[1]; match &self.direction { - Direction::Export => { + Direction::Export { .. } => { quote_in! { self.body => $['\r'] - i.module.Memory().WriteUint32Le($ptr+$offset, uint32($value)) + i.module.Memory().WriteUint32Le(uint32($ptr+$offset), uint32($value)) } } Direction::Import { .. } => { quote_in! { self.body => $['\r'] - mod.Memory().WriteUint32Le($ptr+$offset, uint32($value)) + mod.Memory().WriteUint32Le(uint32($ptr+$offset), uint32($value)) } } } @@ -748,7 +1000,10 @@ impl Bindgen for Func<'_> { Operand::SingleValue(_) => panic!( "impossible: expected Operand::MultiValue but got Operand::SingleValue" ), - Operand::MultiValue(bindings) => bindings, + Operand::DoubleValue(ok, err) => (ok, err), + Operand::MultiValue(_) => panic!( + "impossible: expected Operand::DoubleValue but got Operand::MultiValue" + ), }; quote_in! { self.body => $['\r'] @@ -808,10 +1063,9 @@ impl Bindgen for Func<'_> { } }; - results.push(Operand::MultiValue((result.into(), ok.into()))); + results.push(Operand::DoubleValue(result.into(), ok.into())); } Instruction::OptionLower { - payload: Type::String, results: result_types, .. } => { @@ -820,6 +1074,10 @@ impl Bindgen for Func<'_> { let tmp = self.tmp(); + // If there are no result_types, then the payload will be a pointer, + // because that's how we represent optionals in Go. + let is_pointer = result_types.is_empty(); + let mut vars: Tokens = Tokens::new(); for i in 0..result_types.len() { let variant = &format!("variant{tmp}_{i}"); @@ -848,23 +1106,24 @@ impl Bindgen for Func<'_> { Operand::Literal(_) => { panic!("impossible: expected Operand::MultiValue but got Operand::Literal") } - // TODO(#7): This is a weird hack to implement `option` - // as arguments that currently only works for strings - // because it checks the empty string as the zero value to - // consider it None Operand::SingleValue(value) => { quote_in! { self.body => $['\r'] $vars - if $value == "" { + if $REFLECT_VALUE_OF($value).IsZero() { $none_block } else { - variantPayload := $value + variantPayload := $(if is_pointer => *)$value $some_block } }; } - Operand::MultiValue((value, ok)) => { + Operand::MultiValue(_) => { + panic!( + "impossible: expected Operand::DoubleValue but got Operand::MultiValue" + ) + } + Operand::DoubleValue(value, ok) => { quote_in! { self.body => $['\r'] if $ok { @@ -877,7 +1136,6 @@ impl Bindgen for Func<'_> { } }; } - Instruction::OptionLower { .. } => todo!("implement instruction: {inst:?}"), Instruction::RecordLower { record, .. } => { let tmp = self.tmp(); let operand = &operands[0]; @@ -894,10 +1152,59 @@ impl Bindgen for Func<'_> { Instruction::RecordLift { record, name, .. } => { let tmp = self.tmp(); let value = &format!("value{tmp}"); - let fields = record + + // Generate pointer conversion code for optional fields + let converted_operands: Vec<_> = record .fields .iter() .zip(operands) + .enumerate() + .map(|(i, (field, op))| { + let field_type = match resolve_type(&field.ty, resolve) { + GoType::ValueOrOk(inner_type) => GoType::Pointer(inner_type), + other => other, + }; + let op_clone = op.clone(); + quote_in! { self.body => + $['\r'] + }; + match (&field_type, &op_clone) { + (GoType::Pointer(inner_type), Operand::DoubleValue(val, ok)) => { + quote_in! { self.body => + $['\r'] + }; + let ptr_var_name = format!("ptr{tmp}x{i}"); + quote_in! { self.body => + $['\r'] + }; + let val_ident = GoIdentifier::local(val); + let ok_ident = GoIdentifier::local(ok); + let ptr_var_ident = &GoIdentifier::local(&ptr_var_name); + quote_in! { self.body => + $['\r'] + var $(ptr_var_ident) *$(inner_type.as_ref()) + if $(&ok_ident) { + $(ptr_var_ident) = &$(&val_ident) + } else { + $(ptr_var_ident) = nil + } + }; + Operand::SingleValue(ptr_var_name) + } + _ => { + quote_in! { self.body => + $['\r'] + }; + op_clone + } + } + }) + .collect(); + + let fields = record + .fields + .iter() + .zip(&converted_operands) .map(|(field, op)| (GoIdentifier::public(&field.name), op)); quote_in! {self.body => @@ -1004,9 +1311,10 @@ impl Bindgen for Func<'_> { let value = &operands[0]; let default = &format!("default{tmp}"); - for (i, typ) in result_types.iter().enumerate() { + for (i, _typ) in result_types.iter().enumerate() { let variant_item = &format!("variant{tmp}_{i}"); - let typ = resolve_wasm_type(typ); + // TODO: Use uint64 for all variant variables since they hold encoded WebAssembly values + let typ = GoType::Uint64; quote_in! { self.body => $['\r'] var $variant_item $typ @@ -1014,8 +1322,60 @@ impl Bindgen for Func<'_> { results.push(Operand::SingleValue(variant_item.into())); } + // Find the parent variant's name by comparing case names + let variant_name = resolve.types.iter().find_map(|(_, type_def)| { + if let TypeDefKind::Variant(v) = &type_def.kind { + // Compare case names to identify the matching variant + if v.cases.len() == variant.cases.len() + && v.cases + .iter() + .zip(variant.cases.iter()) + .all(|(a, b)| a.name == b.name) + { + type_def.name.as_ref() + } else { + None + } + } else { + None + } + }); + + let variant_name = match variant_name { + Some(name) => name, + None => { + eprintln!("Warning: Could not find variant name, using 'Unknown'"); + "Unknown" + } + }; + + // Pre-generate all prefixed case names to handle string lifetimes + let case_names: Vec = variant + .cases + .iter() + .map(|case| { + let capitalized_case = case + .name + .replace("-", " ") + .split_whitespace() + .map(|word| { + let mut chars = word.chars(); + match chars.next() { + None => String::new(), + Some(first) => { + first.to_uppercase().collect::() + chars.as_str() + } + } + }) + .collect::(); + format!("{}{}", variant_name, capitalized_case) + }) + .collect(); + let mut cases: Tokens = Tokens::new(); - for (case, (block, block_results)) in variant.cases.iter().zip(blocks) { + for ((_case, (block, block_results)), case_name) in + variant.cases.iter().zip(blocks).zip(case_names.iter()) + { let mut assignments: Tokens = Tokens::new(); for (i, result) in block_results.iter().enumerate() { let variant_item = &format!("variant{tmp}_{i}"); @@ -1025,7 +1385,7 @@ impl Bindgen for Func<'_> { }; } - let name = GoIdentifier::public(case.name.clone()); + let name = GoIdentifier::public(case_name.clone()); quote_in! { cases => $['\r'] case $name: @@ -1086,16 +1446,168 @@ impl Bindgen for Func<'_> { Instruction::I32Load8S { .. } => todo!("implement instruction: {inst:?}"), Instruction::I32Load16U { .. } => todo!("implement instruction: {inst:?}"), Instruction::I32Load16S { .. } => todo!("implement instruction: {inst:?}"), - Instruction::I64Load { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F32Load { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F64Load { .. } => todo!("implement instruction: {inst:?}"), + Instruction::I64Load { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let default = &format!("default{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if !$ok { + var $default $(typ.as_ref()) + return $default, $ERRORS_NEW("failed to read i64 from memory") + } + } + GoResult::Anon(GoType::Error) => { + if !$ok { + return $ERRORS_NEW("failed to read i64 from memory") + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read i64 from memory")) + } + } + }) + }; + results.push(Operand::SingleValue(value.into())); + } + Instruction::F32Load { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let default = &format!("default{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if !$ok { + var $default $(typ.as_ref()) + return $default, $ERRORS_NEW("failed to read f64 from memory") + } + } + GoResult::Anon(GoType::Error) => { + if !$ok { + return $ERRORS_NEW("failed to read f64 from memory") + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read f64 from memory")) + } + } + }) + }; + results.push(Operand::SingleValue(value.into())); + } + Instruction::F64Load { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tmp = self.tmp(); + let value = &format!("value{tmp}"); + let ok = &format!("ok{tmp}"); + let default = &format!("default{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $value, $ok := i.module.Memory().ReadUint64Le(uint32($operand + $offset)) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if !$ok { + var $default $(typ.as_ref()) + return $default, $ERRORS_NEW("failed to read f64 from memory") + } + } + GoResult::Anon(GoType::Error) => { + if !$ok { + return $ERRORS_NEW("failed to read f64 from memory") + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if !$ok { + panic($ERRORS_NEW("failed to read f64 from memory")) + } + } + }) + }; + results.push(Operand::SingleValue(value.into())); + } Instruction::I32Store16 { .. } => todo!("implement instruction: {inst:?}"), Instruction::I64Store { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F32Store { .. } => todo!("implement instruction: {inst:?}"), - Instruction::F64Store { .. } => todo!("implement instruction: {inst:?}"), + Instruction::F32Store { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tag = &operands[0]; + let ptr = &operands[1]; + match &self.direction { + Direction::Export { .. } => { + quote_in! { self.body => + $['\r'] + i.module.Memory().WriteUint64Le(uint32($ptr+$offset), $tag) + } + } + Direction::Import { .. } => { + quote_in! { self.body => + $['\r'] + mod.Memory().WriteUint64Le(uint32($ptr+$offset), $tag) + } + } + } + } + Instruction::F64Store { offset } => { + // TODO(#58): Support additional ArchitectureSize + let offset = offset.size_wasm32(); + let tag = &operands[0]; + let ptr = &operands[1]; + match &self.direction { + Direction::Export { .. } => { + quote_in! { self.body => + $['\r'] + i.module.Memory().WriteUint64Le(uint32($ptr+$offset), $tag) + } + } + Direction::Import { .. } => { + quote_in! { self.body => + $['\r'] + mod.Memory().WriteUint64Le(uint32($ptr+$offset), $tag) + } + } + } + } Instruction::I32FromChar => todo!("implement instruction: {inst:?}"), - Instruction::I64FromU64 => todo!("implement instruction: {inst:?}"), - Instruction::I64FromS64 => todo!("implement instruction: {inst:?}"), + Instruction::I64FromU64 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := int64($operand) + } + results.push(Operand::SingleValue(value.into())); + } + Instruction::I64FromS64 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := $operand + } + results.push(Operand::SingleValue(value.into())); + } Instruction::I32FromS32 => { let tmp = self.tmp(); let value = format!("value{tmp}"); @@ -1120,8 +1632,26 @@ impl Bindgen for Func<'_> { } results.push(Operand::SingleValue(value)) } - Instruction::CoreF32FromF32 => todo!("implement instruction: {inst:?}"), - Instruction::CoreF64FromF64 => todo!("implement instruction: {inst:?}"), + Instruction::CoreF32FromF32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_ENCODE_F32(float32($operand)) + }; + results.push(Operand::SingleValue(result.into())); + } + Instruction::CoreF64FromF64 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_ENCODE_F64(float64($operand)) + }; + results.push(Operand::SingleValue(result.into())); + } // TODO: Validate the Go cast truncates the upper bits in the I32 Instruction::S8FromI32 => { let tmp = self.tmp(); @@ -1177,21 +1707,226 @@ impl Bindgen for Func<'_> { results.push(Operand::SingleValue(result.into())); } Instruction::S64FromI64 => todo!("implement instruction: {inst:?}"), - Instruction::U64FromI64 => todo!("implement instruction: {inst:?}"), + Instruction::U64FromI64 => { + let tmp = self.tmp(); + let value = format!("value{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $(&value) := uint64($operand) + } + results.push(Operand::SingleValue(value.into())); + } Instruction::CharFromI32 => todo!("implement instruction: {inst:?}"), - Instruction::F32FromCoreF32 => todo!("implement instruction: {inst:?}"), - Instruction::F64FromCoreF64 => todo!("implement instruction: {inst:?}"), - Instruction::TupleLower { .. } => todo!("implement instruction: {inst:?}"), - Instruction::TupleLift { .. } => todo!("implement instruction: {inst:?}"), + Instruction::F32FromCoreF32 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_DECODE_F32($operand) + }; + results.push(Operand::SingleValue(result.into())); + } + Instruction::F64FromCoreF64 => { + let tmp = self.tmp(); + let result = &format!("result{tmp}"); + let operand = &operands[0]; + quote_in! { self.body => + $['\r'] + $result := $WAZERO_API_DECODE_F64($operand) + }; + results.push(Operand::SingleValue(result.into())); + } + Instruction::TupleLower { tuple, .. } => { + let tmp = self.tmp(); + let operand = &operands[0]; + for (i, _) in tuple.types.iter().enumerate() { + let field = GoIdentifier::public(format!("f-{i}")); + let var = &GoIdentifier::local(format!("f-{tmp}-{i}")); + quote_in! { self.body => + $['\r'] + $var := $operand.$field + } + results.push(Operand::SingleValue(var.into())); + } + } + Instruction::TupleLift { tuple, ty } => { + if tuple.types.len() != operands.len() { + panic!( + "impossible: expected {} operands but got {}", + tuple.types.len(), + operands.len() + ); + } + let tmp = self.tmp(); + let value = &GoIdentifier::local(format!("value{tmp}")); + + let mut ty_tokens = Tokens::new(); + if let Some(ty) = resolve + .types + .get(ty.clone()) + .expect("failed to find tuple type definition") + .name + .as_ref() + { + let ty_name = GoIdentifier::public(ty); + ty_name.format_into(&mut ty_tokens); + } else { + ty_tokens.append(static_literal("struct{")); + if let Some((last, typs)) = tuple.types.split_last() { + for (i, typ) in typs.iter().enumerate() { + let go_type = resolve_type(typ, resolve); + let field = GoIdentifier::public(format!("f-{i}")); + field.format_into(&mut ty_tokens); + ty_tokens.space(); + go_type.format_into(&mut ty_tokens); + ty_tokens.append(static_literal(";")); + ty_tokens.space(); + } + let field = GoIdentifier::public(format!("f-{}", typs.len())); + field.format_into(&mut ty_tokens); + let go_type = resolve_type(last, resolve); + ty_tokens.space(); + ty_tokens.append(go_type); + } + ty_tokens.append(static_literal("}")); + } + quote_in! { self.body => + $['\r'] + var $value $ty_tokens + } + for (i, (operand, _)) in operands.iter().zip(&tuple.types).enumerate() { + let field = &GoIdentifier::public(format!("f-{i}")); + quote_in! { self.body => + $['\r'] + $value.$field = $operand + } + } + results.push(Operand::SingleValue(value.into())); + } Instruction::FlagsLower { .. } => todo!("implement instruction: {inst:?}"), Instruction::FlagsLift { .. } => todo!("implement instruction: {inst:?}"), Instruction::VariantLift { .. } => { todo!("implement instruction: {inst:?}") } Instruction::EnumLift { .. } => todo!("implement instruction: {inst:?}"), - Instruction::Malloc { .. } => todo!("implement instruction: {inst:?}"), - Instruction::HandleLower { .. } | Instruction::HandleLift { .. } => { - todo!("implement resources: {inst:?}") + Instruction::Malloc { + realloc, + size, + align, + } => { + let tmp = self.tmp(); + let ptr = &format!("ptr{tmp}"); + let result = &format!("result{tmp}"); + let err = &format!("err{tmp}"); + let default = &format!("default{tmp}"); + let size = size.size_wasm32(); + let align = align.align_wasm32(); + + quote_in! { self.body => + $['\r'] + $result, $err := i.module.ExportedFunction($(quoted(*realloc))).Call(ctx, 0, 0, $align, $size) + $(match &self.result { + GoResult::Anon(GoType::ValueOrError(typ)) => { + if $err != nil { + var $default $(typ.as_ref()) + return $default, $err + } + } + GoResult::Anon(GoType::Error) => { + if $err != nil { + return $err + } + } + GoResult::Anon(_) | GoResult::Empty => { + $(comment(&["The return type doesn't contain an error so we panic if one is encountered"])) + if $err != nil { + panic($err) + } + } + }) + $ptr := $result[0] + } + results.push(Operand::SingleValue(ptr.into())); + } + Instruction::HandleLower { + handle, + name: _, + ty: _ty, + } => match handle { + // Create an `i32` from a handle. + // For constructors, we need to store the resource in the table and return the handle. + // For other cases, just convert to uint32. + Handle::Own(_id) | Handle::Borrow(_id) => { + let tmp = self.tmp(); + let converted = &format!("converted{tmp}"); + let operand = &operands[0]; + + // Check if this is in a constructor context (we have resource_context and just called NewFoo) + match &self.direction { + Direction::Import { + resource_context, .. + } if resource_context.is_some() => { + let ctx = resource_context.as_ref().unwrap(); + let table_var = &ctx.table_var; + + quote_in! { self.body => + $['\r'] + $converted := uint32($table_var.Store($operand)) + } + } + _ => { + quote_in! { self.body => + $['\r'] + $converted := uint32($operand) + } + } + } + + results.push(Operand::SingleValue(converted.into())); + } + }, + Instruction::HandleLift { + handle, + name, + ty: _ty, + } => { + // Convert an i32 from Wasm into a resource handle. + // In the Component Model, the i32 is an index into the resource table. + // We need to get the proper resource type name (with interface prefix). + match handle { + Handle::Own(_id) | Handle::Borrow(_id) => { + let tmp = self.tmp(); + let converted = &format!("converted{tmp}"); + let operand = &operands[0]; + + // Use the properly prefixed resource type name + let resource_type = match &self.direction { + Direction::Import { + resource_context, .. + } + | Direction::Export { + resource_context, .. + } if resource_context.is_some() => { + let ctx = resource_context.as_ref().unwrap(); + // Use kebab-case format for proper identifier conversion + GoIdentifier::private(&format!( + "{}-{}-handle", + ctx.interface_name, ctx.resource_name + )) + } + _ => GoIdentifier::public(*name), + }; + + quote_in! { self.body => + $['\r'] + $converted := $resource_type($operand) + } + + results.push(Operand::SingleValue(converted.into())); + } + } } Instruction::ListCanonLower { .. } | Instruction::ListCanonLift { .. } => { unimplemented!("gravity doesn't represent lists as Canonical") diff --git a/cmd/gravity/src/codegen/imports.rs b/cmd/gravity/src/codegen/imports.rs index 5863ede..c418b07 100644 --- a/cmd/gravity/src/codegen/imports.rs +++ b/cmd/gravity/src/codegen/imports.rs @@ -4,7 +4,8 @@ use genco::prelude::*; use wit_bindgen_core::{ abi::{AbiVariant, LiftLower}, wit_parser::{ - Function, InterfaceId, Resolve, SizeAlign, Type, TypeDefKind, TypeId, World, WorldItem, + Function, InterfaceId, Resolve, SizeAlign, Type, TypeDef, TypeDefKind, TypeId, World, + WorldItem, }, }; @@ -17,10 +18,10 @@ use crate::{ }, }, go::{ - GoIdentifier, GoResult, GoType, + GoIdentifier, GoResult, GoType, comment, imports::{CONTEXT_CONTEXT, WAZERO_API_MODULE}, }, - resolve_type, + resolve_type, resolve_wasm_type, }; /// Analyzer for imports - only does analysis, no code generation @@ -56,6 +57,55 @@ impl<'a> ImportAnalyzer<'a> { } } + // Scan exports for resources that need [resource-new] and [resource-drop] host functions + let mut exported_resources = Vec::new(); + let world_exports = &self.world.exports; + + for (_export_name, world_item) in world_exports.iter() { + if let WorldItem::Interface { id, .. } = world_item { + let interface = &self.resolve.interfaces[*id]; + if let Some(interface_name) = &interface.name { + // Build the fully qualified interface name (e.g., "arcjet:resources/types-a") + let full_interface_name = if let Some(package_id) = &interface.package { + let package = &self.resolve.packages[*package_id]; + format!( + "{}:{}/{}", + package.name.namespace, package.name.name, interface_name + ) + } else { + interface_name.to_string() + }; + + // Get the short interface name for prefixing (e.g., "types-a") + let short_interface_name = + interface_name.split('/').last().unwrap_or(interface_name); + + // Check for resources in this exported interface + for &type_id in interface.types.values() { + let type_def = &self.resolve.types[type_id]; + if matches!( + type_def.kind, + wit_bindgen_core::wit_parser::TypeDefKind::Resource + ) { + if let Some(resource_name) = &type_def.name { + let prefixed_name = + format!("{}-{}", short_interface_name, resource_name); + let wazero_export_module_name = + format!("[export]{}", full_interface_name); + + exported_resources.push(crate::codegen::ir::ExportedResourceInfo { + interface_name: short_interface_name.to_string(), + resource_name: resource_name.clone(), + prefixed_name, + wazero_export_module_name, + }); + } + } + } + } + } + } + // Generate factory-related identifiers let factory_name = GoIdentifier::public(format!("{}-factory", self.world.name)); let instance_name = GoIdentifier::public(format!("{}-instance", self.world.name)); @@ -65,6 +115,7 @@ impl<'a> ImportAnalyzer<'a> { interfaces, standalone_types, standalone_functions, + exported_resources, factory_name, instance_name, constructor_name, @@ -131,7 +182,7 @@ impl<'a> ImportAnalyzer<'a> { InterfaceMethod { name: func.name.clone(), - go_method_name: GoIdentifier::public(&func.name), + go_method_name: GoIdentifier::from_resource_function(&func.name), parameters, return_type, wit_function: func.clone(), @@ -143,7 +194,7 @@ impl<'a> ImportAnalyzer<'a> { let type_name = type_def.name.as_ref().expect("type missing name"); let go_type_name = GoIdentifier::public(type_name); - let definition = self.analyze_type_definition(&type_def.kind); + let definition = self.analyze_type_definition(&type_def); definition.map(|definition| AnalyzedType { name: type_name.clone(), @@ -159,8 +210,8 @@ impl<'a> ImportAnalyzer<'a> { /// is probably a reference to an imported type that we have already analyzed. /// /// TODO: we should probably instead resolve and return type and dedup elsewhere. - fn analyze_type_definition(&self, kind: &TypeDefKind) -> Option { - Some(match kind { + fn analyze_type_definition(&self, type_def: &TypeDef) -> Option { + Some(match &type_def.kind { TypeDefKind::Record(record) => TypeDefinition::Record { fields: record .fields @@ -176,18 +227,26 @@ impl<'a> ImportAnalyzer<'a> { TypeDefKind::Enum(enum_def) => TypeDefinition::Enum { cases: enum_def.cases.iter().map(|c| c.name.clone()).collect(), }, - TypeDefKind::Variant(variant) => TypeDefinition::Variant { - cases: variant - .cases - .iter() - .map(|case| { - ( - case.name.clone(), - case.ty.as_ref().map(|t| resolve_type(t, self.resolve)), - ) - }) - .collect(), - }, + TypeDefKind::Variant(variant) => { + let interface_name = type_def.name.clone().expect("variant should have a name"); + TypeDefinition::Variant { + interface_function_name: GoIdentifier::private(format!( + "is-{}", + interface_name, + )), + cases: variant + .cases + .iter() + .map(|case| { + ( + // TODO(bsull): prefix these with the interface name. + GoIdentifier::public(format!("{}-{}", interface_name, case.name)), + case.ty.as_ref().map(|t| resolve_type(t, self.resolve)), + ) + }) + .collect(), + } + } TypeDefKind::Type(Type::Id(_)) => { // TODO(#4): Only skip this if we have already generated the type return None; @@ -215,13 +274,33 @@ impl<'a> ImportAnalyzer<'a> { } TypeDefKind::Option(_) => todo!("TODO(#4): generate option type definition"), TypeDefKind::Result(_) => todo!("TODO(#4): generate result type definition"), - TypeDefKind::List(_) => todo!("TODO(#4): generate list type definition"), + TypeDefKind::List(ty) => TypeDefinition::Alias { + target: GoType::Slice(Box::new(resolve_type(ty, self.resolve))), + }, TypeDefKind::Future(_) => todo!("TODO(#4): generate future type definition"), TypeDefKind::Stream(_) => todo!("TODO(#4): generate stream type definition"), TypeDefKind::Flags(_) => todo!("TODO(#4):generate flags type definition"), - TypeDefKind::Tuple(_) => todo!("TODO(#4):generate tuple type definition"), - TypeDefKind::Resource => todo!("TODO(#5): implement resources"), - TypeDefKind::Handle(_) => todo!("TODO(#5): implement resources"), + TypeDefKind::Tuple(tuple) => TypeDefinition::Record { + fields: tuple + .types + .iter() + .enumerate() + .map(|(i, t)| { + ( + GoIdentifier::public(format!("f-{i}")), + resolve_type(t, self.resolve), + ) + }) + .collect(), + }, + TypeDefKind::Resource => { + // Resources are handled separately as interfaces with methods + return None; + } + TypeDefKind::Handle(_) => { + // Handles are handled separately in the resource implementation + return None; + } TypeDefKind::Unknown => panic!("cannot generate Unknown type"), }) } @@ -280,13 +359,53 @@ impl<'a> ImportCodeGenerator<'a> { for method in &interface.methods { chain.push(); - let func_builder = - self.generate_host_function_builder(method, &interface.constructor_param_name); + let interface_name = interface.name.split('/').last().unwrap_or(&interface.name); + let func_builder = self.generate_host_function_builder( + method, + &interface.constructor_param_name, + interface_name, + ); quote_in! { chain => $func_builder }; } + // Generate drop handlers for resources + let mut resource_names = std::collections::HashSet::new(); + for method in &interface.methods { + if method.name.contains("[constructor]") { + if let Some(ret) = &method.return_type { + if let crate::go::GoType::OwnHandle(name) + | crate::go::GoType::BorrowHandle(name) + | crate::go::GoType::Resource(name) = &ret.go_type + { + resource_names.insert(name.clone()); + } + } + } + } + + for resource_name in resource_names { + chain.push(); + quote_in! { chain => + NewFunctionBuilder(). + WithFunc(func(ctx $CONTEXT_CONTEXT, mod $WAZERO_API_MODULE, arg0 uint32) { + $(comment(&[ + "[resource-drop]: called when guest drops a resource", + "", + "With borrow-only parameters, guests never take ownership of host resources.", + "Resources stay in host table until host explicitly removes them.", + "This callback is a no-op since host controls the full lifecycle.", + "", + "Note: If we add owned parameter support in the future, this would need", + "to implement ref-counting and state tracking to properly cleanup consumed resources.", + ])) + _ = arg0 + }). + Export($(quoted(&format!("[resource-drop]{}", resource_name)))). + }; + } + chain.push(); quote_in! { chain => Instantiate(ctx) @@ -306,6 +425,9 @@ impl FormatInto for ImportCodeGenerator<'_> { fn format_into(self, tokens: &mut Tokens) { // Generate interface type definitions for interface in &self.analyzed.interfaces { + // Generate resource interfaces first + self.generate_resource_interfaces(interface, tokens); + self.generate_interface_type(interface, tokens); for typ in &interface.types { @@ -321,17 +443,303 @@ impl FormatInto for ImportCodeGenerator<'_> { } impl<'a> ImportCodeGenerator<'a> { + fn generate_resource_interfaces(&self, interface: &AnalyzedInterface, tokens: &mut Tokens) { + use std::collections::{HashMap, HashSet}; + + // Collect methods by resource name + let interface_name = interface.name.split('/').last().unwrap_or(&interface.name); + let mut resource_methods: HashMap> = HashMap::new(); + let mut resource_names = HashSet::new(); + + for method in &interface.methods { + // Only include actual resource methods (not freestanding functions with resource params) + if !method.name.contains("[method]") { + continue; + } + + // Check if this is a resource method (has self parameter or returns resource) + let resource_name = if let Some(param) = method.parameters.first() { + match ¶m.go_type { + GoType::OwnHandle(name) + | GoType::BorrowHandle(name) + | GoType::Resource(name) => Some(name.clone()), + _ => None, + } + } else if let Some(ret) = &method.return_type { + match &ret.go_type { + GoType::OwnHandle(name) + | GoType::BorrowHandle(name) + | GoType::Resource(name) => Some(name.clone()), + _ => None, + } + } else { + None + }; + + if let Some(resource_name) = resource_name { + resource_names.insert(resource_name.clone()); + resource_methods + .entry(resource_name) + .or_insert_with(Vec::new) + .push(method); + } + } + + // First generate handle type aliases for each resource + for resource_name in &resource_names { + let prefixed_name = format!("{}-{}", interface_name, resource_name); + let handle_name = format!("{}-handle", prefixed_name); + let go_name = GoIdentifier::private(&handle_name); + let comment_text = format!( + "{} is a handle to the {} resource in the {} interface.", + String::from(&go_name), + resource_name, + interface_name + ); + quote_in! { *tokens => + $['\n'] + $(comment(&[comment_text.as_str()])) + type $go_name uint32 + } + } + + // Generate interface for each resource + for (resource_name, methods) in resource_methods { + let prefixed_name = format!("{}-{}", interface_name, resource_name); + let interface_type_name = GoIdentifier::public(&prefixed_name); + + // Filter out constructor and generate method signatures + let method_sigs = methods + .iter() + .filter(|m| !m.name.contains("[constructor]")) + .map(|method| { + let method_name = &method.go_method_name; + + // Skip 'self' parameter for method signatures + let params = method + .parameters + .iter() + .skip(1) + .map(|p| quote!($(&p.name) $(&p.go_type))) + .collect::>(); + + let return_type = method + .return_type + .as_ref() + .map(|t| GoResult::Anon(t.go_type.clone())) + .unwrap_or(GoResult::Empty); + + // Add context parameter at the beginning + if params.is_empty() { + quote!($method_name(ctx $CONTEXT_CONTEXT) $return_type) + } else { + quote!($method_name(ctx $CONTEXT_CONTEXT, $(for p in params join (, ) => $p)) $return_type) + } + }) + .collect::>(); + + quote_in! { *tokens => + $['\n'] + type $interface_type_name interface { + $(for sig in method_sigs join ($['\r']) => $sig) + } + } + } + } + fn generate_interface_type(&self, interface: &AnalyzedInterface, tokens: &mut Tokens) { + // Collect resource names and their type parameters + let interface_name = interface.name.split('/').last().unwrap_or(&interface.name); + let mut resource_names = Vec::new(); + let mut resource_type_params: Vec<(GoIdentifier, GoIdentifier, GoIdentifier)> = Vec::new(); + + for method in &interface.methods { + if method.name.contains("[constructor]") { + if let Some(ret) = &method.return_type { + if let GoType::OwnHandle(name) + | GoType::BorrowHandle(name) + | GoType::Resource(name) = &ret.go_type + { + let prefixed_name = format!("{}-{}", interface_name, name); + let pointer_interface_name = + GoIdentifier::public(format!("p-{}", prefixed_name)); + let value_type_param = + GoIdentifier::public(format!("t-{}-value", prefixed_name)); + let pointer_type_param = + GoIdentifier::public(format!("p-t-{}", prefixed_name)); + resource_names.push(name.clone()); + // Store the actual identifiers, not quotes + resource_type_params.push(( + value_type_param, + pointer_type_param, + pointer_interface_name, + )); + } + } + } + } + + // Generate method signatures with type parameters + // Include constructors and freestanding functions (which may take/return resources) let methods = interface .methods .iter() - .map(|method| self.generate_method_signature(method)); + .filter(|m| { + m.name.contains("[constructor]") || + (!m.name.contains("[method]") && !m.name.contains("[static]")) + }) + .map(|method| { + let is_constructor = method.name.contains("[constructor]"); + let is_freestanding = !method.name.contains("[constructor]") && + !method.name.contains("[method]") && + !method.name.contains("[static]"); + + let mut sig = self.generate_method_signature_with_interface(method, interface_name); + + // Replace return type with type parameter for constructors + if is_constructor { + if let Some(ret) = &method.return_type { + if let GoType::OwnHandle(name) | GoType::BorrowHandle(name) | GoType::Resource(name) = &ret.go_type { + let prefixed_name = format!("{}-{}", interface_name, name); + // Use value type parameter for return type (not pointer) + let value_type_param = GoIdentifier::public(format!("t-{}-value", prefixed_name)); + // Add context parameter at the beginning + if method.parameters.is_empty() { + sig = quote!($(&method.go_method_name)(ctx $CONTEXT_CONTEXT) $value_type_param); + } else { + sig = quote!($(&method.go_method_name)(ctx $CONTEXT_CONTEXT, $(for p in &method.parameters join (, ) => $(&p.name) $(&p.go_type))) $value_type_param); + } + } + } + } else if is_freestanding { + // For freestanding functions, replace resource parameters with type parameters + let params_with_ctx = std::iter::once(quote!(ctx $CONTEXT_CONTEXT)).chain( + method.parameters.iter().map(|p| { + let param_type = match &p.go_type { + GoType::BorrowHandle(name) => { + // Use pointer type parameter for borrowed resource params + let prefixed_name = format!("{}-{}", interface_name, name); + let pointer_type_param = GoIdentifier::public(format!("p-t-{}", prefixed_name)); + quote!($pointer_type_param) + } + GoType::OwnHandle(name) | GoType::Resource(name) => { + // Use value type parameter for owned/resource params + let prefixed_name = format!("{}-{}", interface_name, name); + let value_type_param = GoIdentifier::public(format!("t-{}-value", prefixed_name)); + quote!($value_type_param) + } + _ => { + let resolved = self.resolve_type_with_interface(&p.go_type, interface_name); + quote!($(&resolved)) + } + }; + quote!($(&p.name) $param_type) + }) + ); + + let return_type = method.return_type.as_ref().map(|r| { + match &r.go_type { + GoType::BorrowHandle(name) => { + // Use pointer type parameter for borrowed resource return types + let prefixed_name = format!("{}-{}", interface_name, name); + let pointer_type_param = GoIdentifier::public(format!("p-t-{}", prefixed_name)); + quote!($pointer_type_param) + } + GoType::OwnHandle(name) | GoType::Resource(name) => { + // Use value type parameter for owned/resource return types + let prefixed_name = format!("{}-{}", interface_name, name); + let value_type_param = GoIdentifier::public(format!("t-{}-value", prefixed_name)); + quote!($value_type_param) + } + _ => { + let resolved = self.resolve_type_with_interface(&r.go_type, interface_name); + quote!($(&resolved)) + } + } + }); + + if let Some(ret) = return_type { + sig = quote!($(&method.go_method_name)($(for p in params_with_ctx join (, ) => $p)) $ret); + } else { + sig = quote!($(&method.go_method_name)($(for p in params_with_ctx join (, ) => $p))); + } + } + + sig + }); + + if resource_type_params.is_empty() { + // No resources, generate regular interface + // Only include freestanding functions (not resource methods) + let methods = interface + .methods + .iter() + .filter(|m| { + !m.name.contains("[method]") + && !m.name.contains("[static]") + && !m.name.contains("[constructor]") + }) + .map(|method| self.generate_method_signature(method)); + + quote_in! { *tokens => + $['\n'] + type $(&interface.go_interface_name) interface { + $(for method in methods join ($['\r']) => $method) + } + } + } else { + // Generate generic interface with type parameters + quote_in! { *tokens => + $['\n'] + type $(&interface.go_interface_name)[$(for (value_param, pointer_param, pointer_iface) in &resource_type_params join (, ) => $value_param any, $pointer_param $pointer_iface[$value_param])] interface { + $(for method in methods join ($['\r']) => $method) + } + } + } + } + + /// Resolve a type with interface context, adding prefixes to resource handles + fn resolve_type_with_interface(&self, typ: &GoType, interface_name: &str) -> GoType { + match typ { + GoType::OwnHandle(name) | GoType::BorrowHandle(name) => { + let prefixed_name = format!("{}-{}", interface_name, name); + GoType::Resource(prefixed_name) + } + GoType::Resource(name) => { + // Check if it's already prefixed + if name.contains('-') { + typ.clone() + } else { + let prefixed_name = format!("{}-{}", interface_name, name); + GoType::Resource(prefixed_name) + } + } + _ => typ.clone(), + } + } + + fn generate_method_signature_with_interface( + &self, + method: &InterfaceMethod, + interface_name: &str, + ) -> Tokens { + let return_type = method + .return_type + .as_ref() + .map(|r| self.resolve_type_with_interface(&r.go_type, interface_name)); + + let params_with_ctx = std::iter::once(quote!(ctx $CONTEXT_CONTEXT)).chain( + method.parameters.iter().map(|p| { + let resolved_type = self.resolve_type_with_interface(&p.go_type, interface_name); + quote!($(&p.name) $(&resolved_type)) + }), + ); - quote_in! { *tokens => - $['\n'] - type $(&interface.go_interface_name) interface { - $(for method in methods join ($['\r']) => $method) + match return_type { + Some(typ) => { + quote!($(&method.go_method_name)($(for p in params_with_ctx join (, ) => $p)) $(&typ)) } + None => quote!($(&method.go_method_name)($(for p in params_with_ctx join (, ) => $p))), } } @@ -353,10 +761,17 @@ impl<'a> ImportCodeGenerator<'a> { fn generate_type_definition(&self, typ: &AnalyzedType, tokens: &mut Tokens) { match &typ.definition { TypeDefinition::Record { fields } => { + let maybe_pointer_fields = fields.iter().map(|(name, typ)| { + if let GoType::ValueOrOk(inner_type) = typ { + (name, GoType::Pointer(inner_type.clone())) + } else { + (name, typ.clone()) + } + }); quote_in! { *tokens => $['\n'] type $(&typ.go_type_name) struct { - $(for (field_name, field_type) in fields join ($['\n']) => + $(for (field_name, field_type) in maybe_pointer_fields join ($['\n']) => $field_name $field_type ) } @@ -396,10 +811,32 @@ impl<'a> ImportCodeGenerator<'a> { // Primitive type: $(typ.name) } } - TypeDefinition::Variant { .. } => { + TypeDefinition::Variant { + interface_function_name, + cases, + } => { quote_in! { *tokens => $['\n'] - // Variant type: $(typ.name) (TODO: implement) + type $(&typ.go_type_name) interface { + $(interface_function_name)() + } + $['\n'] + } + + for (case_name, case_type) in cases { + if let Some(inner_type) = case_type { + quote_in! { *tokens => + $['\n'] + type $case_name $inner_type + func ($case_name) $interface_function_name() {} + } + } else { + quote_in! { *tokens => + $['\n'] + type $&case_name $&inner_type + func ($&case_name) $&variant_function() {} + } + } } } } @@ -411,6 +848,8 @@ impl<'a> ImportCodeGenerator<'a> { // The name of the parameter representing the interface instance // in the generated function. param_name: &GoIdentifier, + // The interface name for resource table lookup + interface_name: &str, ) -> Tokens { let func_name = &method.name; @@ -425,10 +864,171 @@ impl<'a> ImportCodeGenerator<'a> { .wasm_signature(AbiVariant::GuestImport, &method.wit_function); let result = if wasm_sig.results.is_empty() { GoResult::Empty + } else if wasm_sig.results.len() == 1 { + GoResult::Anon(resolve_wasm_type(&wasm_sig.results[0])) } else { - todo!("implement handling of wasm signatures with results"); + GoResult::Anon(GoType::MultiReturn( + wasm_sig.results.iter().map(resolve_wasm_type).collect(), + )) + }; + // Detect if this is a resource constructor or method + let func_name_str = &method.name; + let mut f = if func_name_str.starts_with("[constructor]") + || func_name_str.starts_with("[method]") + { + // Extract resource name + let resource_name = if func_name_str.starts_with("[constructor]") { + func_name_str.strip_prefix("[constructor]").unwrap() + } else if func_name_str.starts_with("[method]") { + // Format: "[method]resource-name.method-name" + let parts: Vec<&str> = func_name_str + .strip_prefix("[method]") + .unwrap() + .split('.') + .collect(); + parts[0] + } else { + "" + }; + + // Convert to camelCase for table variable name + let interface_pascal = interface_name + .split('-') + .map(|s| { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } + }) + .collect::>() + .join(""); + let resource_pascal = resource_name + .split('-') + .map(|s| { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } + }) + .collect::>() + .join(""); + + // Convert PascalCase to camelCase by lowercasing first character + let interface_camel = { + let mut c = interface_pascal.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_lowercase().collect::() + c.as_str(), + } + }; + + let table_var = format!("{}{}ResourceTable", interface_camel, resource_pascal); + + Func::import_with_resource( + param_name, + result, + self.sizes, + interface_name.to_string(), + resource_name.to_string(), + table_var, + ) + } else { + // Freestanding function - check if it has resource parameters or returns a resource + // If so, extract resource info from the first resource parameter or return type + let resource_param = method.wit_function.params.iter().find_map(|(_, typ)| { + if let wit_bindgen_core::wit_parser::Type::Id(id) = typ { + let type_def = &self.resolve.types[*id]; + if let wit_bindgen_core::wit_parser::TypeDefKind::Handle(handle) = + &type_def.kind + { + match handle { + wit_bindgen_core::wit_parser::Handle::Own(resource_id) + | wit_bindgen_core::wit_parser::Handle::Borrow(resource_id) => { + let resource_def = &self.resolve.types[*resource_id]; + resource_def.name.as_ref().map(|name| name.clone()) + } + } + } else { + None + } + } else { + None + } + }); + + // Also check return type for resources + let resource_return = method.wit_function.result.as_ref().and_then(|ret_typ| { + if let wit_bindgen_core::wit_parser::Type::Id(id) = ret_typ { + let type_def = &self.resolve.types[*id]; + match &type_def.kind { + wit_bindgen_core::wit_parser::TypeDefKind::Handle(handle) => match handle { + wit_bindgen_core::wit_parser::Handle::Own(resource_id) + | wit_bindgen_core::wit_parser::Handle::Borrow(resource_id) => { + let resource_def = &self.resolve.types[*resource_id]; + resource_def.name.as_ref().map(|name| name.clone()) + } + }, + wit_bindgen_core::wit_parser::TypeDefKind::Resource => { + type_def.name.as_ref().map(|name| name.clone()) + } + _ => None, + } + } else { + None + } + }); + + let resource_name = resource_param.or(resource_return); + + if let Some(resource_name) = resource_name { + // Build the resource context for freestanding function with resource params + let interface_pascal = interface_name + .split('-') + .map(|s| { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } + }) + .collect::>() + .join(""); + let resource_pascal = resource_name + .split('-') + .map(|s| { + let mut c = s.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_uppercase().collect::() + c.as_str(), + } + }) + .collect::>() + .join(""); + + let interface_camel = { + let mut c = interface_pascal.chars(); + match c.next() { + None => String::new(), + Some(f) => f.to_lowercase().collect::() + c.as_str(), + } + }; + + let table_var = format!("{}{}ResourceTable", interface_camel, resource_pascal); + + Func::import_with_resource( + param_name, + result, + self.sizes, + interface_name.to_string(), + resource_name, + table_var, + ) + } else { + Func::import(param_name, result, self.sizes) + } }; - let mut f = Func::import(param_name, result, self.sizes); // Magic wit_bindgen_core::abi::call( @@ -489,6 +1089,7 @@ mod tests { let analyzed = AnalyzedImports { instance_name: GoIdentifier::public("TestInstance"), interfaces: vec![], + exported_resources: vec![], standalone_functions: vec![], standalone_types: vec![], factory_name: GoIdentifier::public("TestFactory"), @@ -512,7 +1113,8 @@ mod tests { }; let param_name = GoIdentifier::private("handler"); - let result = generator.generate_host_function_builder(&method, ¶m_name); + let result = + generator.generate_host_function_builder(&method, ¶m_name, "test-interface"); // The result should contain the WIT type-driven generation let code_str = result.to_string().unwrap(); @@ -531,6 +1133,7 @@ mod tests { interfaces: vec![], standalone_functions: vec![], standalone_types: vec![], + exported_resources: vec![], factory_name: GoIdentifier::public("TestFactory"), constructor_name: GoIdentifier::public("NewTestFactory"), }; @@ -560,7 +1163,8 @@ mod tests { }; let param_name = GoIdentifier::private("handler"); - let result = generator.generate_host_function_builder(&u32_method, ¶m_name); + let result = + generator.generate_host_function_builder(&u32_method, ¶m_name, "test-interface"); // Should have only one uint32 parameter (plus ctx and mod) let code_str = result.to_string().unwrap(); @@ -778,7 +1382,7 @@ mod tests { // Test analyze_type_definition directly with the record kind let type_def = &resolve.types[type_id]; - let analyzed_definition = analyzer.analyze_type_definition(&type_def.kind).unwrap(); + let analyzed_definition = analyzer.analyze_type_definition(&type_def).unwrap(); println!( "Direct analysis of type definition: {:?}", @@ -1000,7 +1604,7 @@ mod tests { // Test record analysis let record_def = &resolve.types[record_type_id]; - let record_analysis = analyzer.analyze_type_definition(&record_def.kind).unwrap(); + let record_analysis = analyzer.analyze_type_definition(&record_def).unwrap(); match record_analysis { TypeDefinition::Record { .. } => { @@ -1013,7 +1617,7 @@ mod tests { // Test alias analysis let alias_def = &resolve.types[alias_type_id]; - let alias_analysis = analyzer.analyze_type_definition(&alias_def.kind).unwrap(); + let alias_analysis = analyzer.analyze_type_definition(&alias_def).unwrap(); match alias_analysis { TypeDefinition::Alias { .. } => { diff --git a/cmd/gravity/src/codegen/ir.rs b/cmd/gravity/src/codegen/ir.rs index 5a79de3..7f87f3f 100644 --- a/cmd/gravity/src/codegen/ir.rs +++ b/cmd/gravity/src/codegen/ir.rs @@ -2,6 +2,20 @@ use wit_bindgen_core::wit_parser::{Function, Type}; use crate::go::{GoIdentifier, GoType}; +/// Information about a resource that is exported from a WASM module. +/// These resources need [resource-new] and [resource-drop] host functions. +#[derive(Debug, Clone)] +pub struct ExportedResourceInfo { + /// The interface name (e.g., "types-a") + pub interface_name: String, + /// The resource name (e.g., "foo") + pub resource_name: String, + /// The prefixed name (e.g., "types-a-foo") + pub prefixed_name: String, + /// The wazero module name for exports (e.g., "[export]arcjet:resources/types-a") + pub wazero_export_module_name: String, +} + /// An analyzed WIT import. #[derive(Debug, Clone)] pub struct AnalyzedImports { @@ -11,6 +25,8 @@ pub struct AnalyzedImports { pub standalone_types: Vec, /// All standalone functions found in the input world. pub standalone_functions: Vec, + /// All exported resources that need [resource-new] and [resource-drop] host functions. + pub exported_resources: Vec, /// The name of the factory type to be generated. pub factory_name: GoIdentifier, @@ -115,7 +131,16 @@ pub enum TypeDefinition { Record { fields: Vec<(GoIdentifier, GoType)> }, /// A union-like type with multiple cases, each optionally carrying data Variant { - cases: Vec<(String, Option)>, + /// The Go identifier to use for the interface function. + /// + /// E.g. the `isFoo` in `type Foo interface { isFoo() }` + interface_function_name: GoIdentifier, + + /// The cases of the variant type. + /// + /// The first element of each tuple is the prefixed name of the case, + /// where the prefix is the interface name. + cases: Vec<(GoIdentifier, Option)>, }, /// A simple enumeration with named constants Enum { cases: Vec }, diff --git a/cmd/gravity/src/go/identifier.rs b/cmd/gravity/src/go/identifier.rs index 7baaebc..46148e9 100644 --- a/cmd/gravity/src/go/identifier.rs +++ b/cmd/gravity/src/go/identifier.rs @@ -43,6 +43,45 @@ impl GoIdentifier { Self::Local { name: name.into() } } + /// Creates a public identifier from a resource function name. + /// + /// Resource function names in WIT have special prefixes: + /// - `[constructor]foo` → `NewFoo` + /// - `[method]foo.get-x` → `GetX` + /// - `[method]foo.increase-x` → `IncreaseX` + /// + /// For regular function names without these prefixes, this behaves + /// the same as `GoIdentifier::public()`. + pub fn from_resource_function(name: T) -> Self + where + T: AsRef, + { + let name = name.as_ref(); + + // Handle [constructor]resource-name + if let Some(resource_name) = name.strip_prefix("[constructor]") { + return Self::Public { + name: format!("new-{}", resource_name), + }; + } + + // Handle [method]resource-name.method-name + if let Some(rest) = name.strip_prefix("[method]") { + // Split on the first dot to separate resource name from method name + if let Some(dot_pos) = rest.find('.') { + let method_name = &rest[dot_pos + 1..]; + return Self::Public { + name: method_name.to_string(), + }; + } + } + + // For regular function names, just use public + Self::Public { + name: name.to_string(), + } + } + /// Returns an iterator over the characters of the underlying name. /// /// This provides access to the raw name without case transformations. @@ -134,4 +173,36 @@ mod tests { (&id).format_into(&mut tokens); assert_eq!(tokens.to_string().unwrap(), "helloWorld"); } + + #[test] + fn test_resource_constructor() { + let id = GoIdentifier::from_resource_function("[constructor]foo"); + let mut tokens = Tokens::::new(); + (&id).format_into(&mut tokens); + assert_eq!(tokens.to_string().unwrap(), "NewFoo"); + } + + #[test] + fn test_resource_method() { + let id = GoIdentifier::from_resource_function("[method]foo.get-x"); + let mut tokens = Tokens::::new(); + (&id).format_into(&mut tokens); + assert_eq!(tokens.to_string().unwrap(), "GetX"); + } + + #[test] + fn test_resource_method_multi_word() { + let id = GoIdentifier::from_resource_function("[method]foo.increase-x"); + let mut tokens = Tokens::::new(); + (&id).format_into(&mut tokens); + assert_eq!(tokens.to_string().unwrap(), "IncreaseX"); + } + + #[test] + fn test_regular_function_name() { + let id = GoIdentifier::from_resource_function("regular-function"); + let mut tokens = Tokens::::new(); + (&id).format_into(&mut tokens); + assert_eq!(tokens.to_string().unwrap(), "RegularFunction"); + } } diff --git a/cmd/gravity/src/go/imports.rs b/cmd/gravity/src/go/imports.rs index d6aa03d..00a3ecd 100644 --- a/cmd/gravity/src/go/imports.rs +++ b/cmd/gravity/src/go/imports.rs @@ -28,3 +28,13 @@ pub static WAZERO_API_ENCODE_I32: GoImport = GoImport("github.com/tetratelabs/wazero/api", "EncodeI32"); pub static WAZERO_API_DECODE_I32: GoImport = GoImport("github.com/tetratelabs/wazero/api", "DecodeI32"); +pub static WAZERO_API_ENCODE_F32: GoImport = + GoImport("github.com/tetratelabs/wazero/api", "EncodeF32"); +pub static WAZERO_API_DECODE_F32: GoImport = + GoImport("github.com/tetratelabs/wazero/api", "DecodeF32"); +pub static WAZERO_API_ENCODE_F64: GoImport = + GoImport("github.com/tetratelabs/wazero/api", "EncodeF64"); +pub static WAZERO_API_DECODE_F64: GoImport = + GoImport("github.com/tetratelabs/wazero/api", "DecodeF64"); +pub static REFLECT_VALUE_OF: GoImport = GoImport("reflect", "ValueOf"); +pub static SYNC_MUTEX: GoImport = GoImport("sync", "Mutex"); diff --git a/cmd/gravity/src/go/operand.rs b/cmd/gravity/src/go/operand.rs index 60d7061..461b86f 100644 --- a/cmd/gravity/src/go/operand.rs +++ b/cmd/gravity/src/go/operand.rs @@ -13,25 +13,12 @@ pub enum Operand { Literal(String), /// A single variable or expression SingleValue(String), - /// A tuple of two values (for multi-value returns) - MultiValue((String, String)), -} - -impl Operand { - /// Returns the primary value of the operand. + /// A tuple of two values (shortcut for for multi-value returns). /// - /// For single values and literals, returns the value itself. - /// For multi-value tuples, returns the first value. - /// - /// # Returns - /// A string representation of the primary value. - pub fn as_string(&self) -> String { - match self { - Operand::Literal(s) => s.clone(), - Operand::SingleValue(s) => s.clone(), - Operand::MultiValue((s1, _)) => s1.clone(), - } - } + /// This is used when returning `val, ok` or `result, err` from Go functions. + DoubleValue(String, String), + /// A tuple of two or more values (for tuples) + MultiValue(Vec), } // Implement genco's FormatInto for Operand so it can be used in quote! macros @@ -40,11 +27,21 @@ impl FormatInto for &Operand { match self { Operand::Literal(val) => tokens.append(ItemStr::from(val)), Operand::SingleValue(val) => tokens.append(ItemStr::from(val)), - Operand::MultiValue((val1, val2)) => { - tokens.append(ItemStr::from(val1)); + Operand::DoubleValue(val, ok) => { + tokens.append(ItemStr::from(val)); tokens.append(static_literal(",")); tokens.space(); - tokens.append(ItemStr::from(val2)); + tokens.append(ItemStr::from(ok)); + } + Operand::MultiValue(vals) => { + if let Some((last, vals)) = vals.split_last() { + for val in vals.iter() { + tokens.append(val); + tokens.append(static_literal(",")); + tokens.space(); + } + tokens.append(last); + } } } } @@ -86,10 +83,22 @@ mod tests { } #[test] - fn test_operand_multi_value() { - let op = Operand::MultiValue(("val1".to_string(), "val2".to_string())); + fn test_operand_double_value() { + let op = Operand::DoubleValue("val1".to_string(), "val2".to_string()); let mut tokens = Tokens::::new(); op.format_into(&mut tokens); assert_eq!(tokens.to_string().unwrap(), "val1, val2"); } + + #[test] + fn test_operand_multi_value() { + let op = Operand::MultiValue(vec![ + "val1".to_string(), + "val2".to_string(), + "val3".to_string(), + ]); + let mut tokens = Tokens::::new(); + op.format_into(&mut tokens); + assert_eq!(tokens.to_string().unwrap(), "val1, val2, val3"); + } } diff --git a/cmd/gravity/src/go/type.rs b/cmd/gravity/src/go/type.rs index 47cfaba..fc96546 100644 --- a/cmd/gravity/src/go/type.rs +++ b/cmd/gravity/src/go/type.rs @@ -37,7 +37,7 @@ pub enum GoType { /// Interface type (for variants/discriminated unions) Interface, // Pointer to another type - // Pointer(Box), + Pointer(Box), /// Result type with Ok value ValueOrOk(Box), /// Result type with Error value @@ -45,9 +45,15 @@ pub enum GoType { /// Slice/array of another type Slice(Box), /// Multi-return type (for functions returning arbitrary multiple values) - // MultiReturn(Vec), + MultiReturn(Vec), /// User-defined type (records, enums, type aliases) UserDefined(String), + /// Resource type (Component Model resources) + Resource(String), + /// Owned handle to a resource + OwnHandle(String), + /// Borrowed handle to a resource + BorrowHandle(String), /// Represents no value/void Nothing, } @@ -87,6 +93,9 @@ impl GoType { | GoType::Float32 | GoType::Float64 => false, + // Resources and handles are represented as integers and don't need cleanup + GoType::Resource(_) | GoType::OwnHandle(_) | GoType::BorrowHandle(_) => false, + // String and slices allocate memory and need cleanup GoType::String | GoType::Slice(_) => true, @@ -120,8 +129,12 @@ impl GoType { // Nothing represents no value, so no cleanup needed GoType::Nothing => false, - // TODO - figure out if a pointer needs cleanup, once implemented. - // GoType::Pointer(_) => todo!() + + // A pointer probably needs cleanup, not sure? + GoType::Pointer(_) => true, + + // Multi-return types need cleanup if their inner types do + GoType::MultiReturn(inner) => inner.iter().any(|t| t.needs_cleanup()), } } } @@ -159,17 +172,41 @@ impl FormatInto for &GoType { tokens.append(static_literal("[]")); typ.as_ref().format_into(tokens); } - // GoType::MultiReturn(typs) => { - // tokens.append(quote!($(for typ in typs join (, ) => $typ))) - // } - // GoType::Pointer(typ) => { - // tokens.append(static_literal("*")); - // typ.as_ref().format_into(tokens); - // } + GoType::MultiReturn(typs) => { + tokens.append(static_literal("struct{")); + if let Some((last, typs)) = typs.split_last() { + for (i, typ) in typs.iter().enumerate() { + let field = GoIdentifier::public(format!("f-{i}")); + field.format_into(tokens); + tokens.space(); + typ.format_into(tokens); + tokens.append(static_literal(";")); + tokens.space(); + } + let field = GoIdentifier::public(format!("f-{}", typs.len())); + field.format_into(tokens); + tokens.space(); + tokens.append(last); + } + tokens.append(static_literal("}")); + } + GoType::Pointer(typ) => { + tokens.append(static_literal("*")); + typ.as_ref().format_into(tokens); + } GoType::UserDefined(name) => { let id = GoIdentifier::public(name); id.format_into(tokens) } + GoType::Resource(name) | GoType::OwnHandle(name) | GoType::BorrowHandle(name) => { + // Handle types (ending with -handle) should use private identifier for lowercase first letter + let id = if name.ends_with("-handle") { + GoIdentifier::private(name) + } else { + GoIdentifier::public(name) + }; + id.format_into(tokens) + } GoType::Nothing => (), } } @@ -243,26 +280,26 @@ mod tests { assert_eq!(tokens.to_string().unwrap(), "[]int32"); } - // #[test] - // fn test_pointer() { - // let typ = GoType::Pointer(Box::new(GoType::String)); - // let mut tokens = Tokens::::new(); - // (&typ).format_into(&mut tokens); - // assert_eq!(tokens.to_string().unwrap(), "*string"); - // } + #[test] + fn test_pointer() { + let typ = GoType::Pointer(Box::new(GoType::String)); + let mut tokens = Tokens::::new(); + (&typ).format_into(&mut tokens); + assert_eq!(tokens.to_string().unwrap(), "*string"); + } - // #[test] - // fn test_nested_types() { - // // Test *[]string - // let typ = GoType::Pointer(Box::new(GoType::Slice(Box::new(GoType::String)))); - // let mut tokens = Tokens::::new(); - // (&typ).format_into(&mut tokens); - // assert_eq!(tokens.to_string().unwrap(), "*[]string"); + #[test] + fn test_nested_types() { + // Test *[]string + let typ = GoType::Pointer(Box::new(GoType::Slice(Box::new(GoType::String)))); + let mut tokens = Tokens::::new(); + (&typ).format_into(&mut tokens); + assert_eq!(tokens.to_string().unwrap(), "*[]string"); - // // Test [][]uint8 - // let typ = GoType::Slice(Box::new(GoType::Slice(Box::new(GoType::Uint8)))); - // let mut tokens = Tokens::::new(); - // (&typ).format_into(&mut tokens); - // assert_eq!(tokens.to_string().unwrap(), "[][]uint8"); - // } + // Test [][]uint8 + let typ = GoType::Slice(Box::new(GoType::Slice(Box::new(GoType::Uint8)))); + let mut tokens = Tokens::::new(); + (&typ).format_into(&mut tokens); + assert_eq!(tokens.to_string().unwrap(), "[][]uint8"); + } } diff --git a/cmd/gravity/src/lib.rs b/cmd/gravity/src/lib.rs index 0b6cc2b..daf8837 100644 --- a/cmd/gravity/src/lib.rs +++ b/cmd/gravity/src/lib.rs @@ -63,13 +63,46 @@ pub fn resolve_type(typ: &Type, resolve: &Resolve) -> GoType { TypeDefKind::Record(_) => { GoType::UserDefined(name.clone().expect("expected record to have a name")) } - TypeDefKind::Resource => todo!("TODO(#5): implement resources"), - TypeDefKind::Handle(_) => todo!("TODO(#5): implement resources"), + TypeDefKind::Resource => { + GoType::Resource(name.clone().expect("expected resource to have a name")) + } + TypeDefKind::Handle(handle) => match handle { + wit_bindgen_core::wit_parser::Handle::Own(id) => { + let resource_def = resolve + .types + .get(*id) + .expect("failed to find resource definition for owned handle"); + let resource_name = resource_def + .name + .clone() + .expect("expected resource to have a name"); + GoType::OwnHandle(resource_name) + } + wit_bindgen_core::wit_parser::Handle::Borrow(id) => { + let resource_def = resolve + .types + .get(*id) + .expect("failed to find resource definition for borrowed handle"); + let resource_name = resource_def + .name + .clone() + .expect("expected resource to have a name"); + GoType::BorrowHandle(resource_name) + } + }, TypeDefKind::Flags(_) => todo!("TODO(#4): implement flag conversion"), - TypeDefKind::Tuple(_) => todo!("TODO(#4): implement tuple conversion"), - // Variants are handled as an empty interfaces in type signatures; however, that - // means they require runtime type reflection - TypeDefKind::Variant(_) => GoType::Interface, + TypeDefKind::Tuple(tuple) => GoType::MultiReturn( + tuple + .types + .iter() + .map(|t| resolve_type(t, resolve)) + .collect(), + ), + TypeDefKind::Variant(_) => GoType::UserDefined( + name.clone() + .clone() + .expect("expected variant to have a name"), + ), TypeDefKind::Enum(_) => { GoType::UserDefined(name.clone().expect("expected enum to have a name")) } diff --git a/cmd/gravity/tests/cmd/basic.stdout b/cmd/gravity/tests/cmd/basic.stdout index 2fe5028..c806c5d 100644 --- a/cmd/gravity/tests/cmd/basic.stdout +++ b/cmd/gravity/tests/cmd/basic.stdout @@ -6,6 +6,7 @@ import "context" import "errors" import "github.com/tetratelabs/wazero" import "github.com/tetratelabs/wazero/api" +import "reflect" import _ "embed" @@ -42,6 +43,7 @@ func NewBasicFactory( ) (*BasicFactory, error) { wazeroRuntime := wazero.NewRuntime(ctx) + // Instantiate import host modules _, err0 := wazeroRuntime.NewHostModuleBuilder("arcjet:basic/logger"). NewFunctionBuilder(). WithFunc(func( @@ -50,12 +52,17 @@ func NewBasicFactory( arg0 uint32, arg1 uint32, ) { + // GetArg { nth: 0 } + // GetArg { nth: 1 } + // StringLift buf0, ok0 := mod.Memory().Read(arg0, arg1) if !ok0 { panic(errors.New("failed to read bytes from memory")) } str0 := string(buf0) + // CallInterface { func: Function { name: "debug", kind: Freestanding, params: [("msg", String)], result: None, docs: Docs { contents: None }, stability: Unknown }, async_: false } logger.Debug(ctx, str0) + // Return { amt: 0, func: Function { name: "debug", kind: Freestanding, params: [("msg", String)], result: None, docs: Docs { contents: None }, stability: Unknown } } }). Export("debug"). NewFunctionBuilder(). @@ -65,12 +72,17 @@ func NewBasicFactory( arg0 uint32, arg1 uint32, ) { + // GetArg { nth: 0 } + // GetArg { nth: 1 } + // StringLift buf0, ok0 := mod.Memory().Read(arg0, arg1) if !ok0 { panic(errors.New("failed to read bytes from memory")) } str0 := string(buf0) + // CallInterface { func: Function { name: "info", kind: Freestanding, params: [("msg", String)], result: None, docs: Docs { contents: None }, stability: Unknown }, async_: false } logger.Info(ctx, str0) + // Return { amt: 0, func: Function { name: "info", kind: Freestanding, params: [("msg", String)], result: None, docs: Docs { contents: None }, stability: Unknown } } }). Export("info"). NewFunctionBuilder(). @@ -80,12 +92,17 @@ func NewBasicFactory( arg0 uint32, arg1 uint32, ) { + // GetArg { nth: 0 } + // GetArg { nth: 1 } + // StringLift buf0, ok0 := mod.Memory().Read(arg0, arg1) if !ok0 { panic(errors.New("failed to read bytes from memory")) } str0 := string(buf0) + // CallInterface { func: Function { name: "warn", kind: Freestanding, params: [("msg", String)], result: None, docs: Docs { contents: None }, stability: Unknown }, async_: false } logger.Warn(ctx, str0) + // Return { amt: 0, func: Function { name: "warn", kind: Freestanding, params: [("msg", String)], result: None, docs: Docs { contents: None }, stability: Unknown } } }). Export("warn"). NewFunctionBuilder(). @@ -95,12 +112,17 @@ func NewBasicFactory( arg0 uint32, arg1 uint32, ) { + // GetArg { nth: 0 } + // GetArg { nth: 1 } + // StringLift buf0, ok0 := mod.Memory().Read(arg0, arg1) if !ok0 { panic(errors.New("failed to read bytes from memory")) } str0 := string(buf0) + // CallInterface { func: Function { name: "error", kind: Freestanding, params: [("msg", String)], result: None, docs: Docs { contents: None }, stability: Unknown }, async_: false } logger.Error(ctx, str0) + // Return { amt: 0, func: Function { name: "error", kind: Freestanding, params: [("msg", String)], result: None, docs: Docs { contents: None }, stability: Unknown } } }). Export("error"). Instantiate(ctx) @@ -171,6 +193,7 @@ func writeString( func (i *BasicInstance) Hello( ctx context.Context, ) (string, error) { + // CallWasm { name: "hello", sig: WasmSignature { params: [], results: [Pointer], indirect_params: false, retptr: true } } raw0, err0 := i.module.ExportedFunction("hello").Call(ctx, ) if err0 != nil { var default0 string @@ -190,25 +213,30 @@ func (i *BasicInstance) Hello( }() results0 := raw0[0] + // I32Load8U { offset: 0 } value1, ok1 := i.module.Memory().ReadByte(uint32(results0 + 0)) if !ok1 { var default1 string return default1, errors.New("failed to read byte from memory") } + // ResultLift { result: Result_ { ok: Some(String), err: Some(String) }, ty: Id { idx: 0 } } var value8 string var err8 error switch value1 { case 0: + // PointerLoad { offset: ptrsz } ptr2, ok2 := i.module.Memory().ReadUint32Le(uint32(results0 + 4)) if !ok2 { var default2 string return default2, errors.New("failed to read pointer from memory") } + // LengthLoad { offset: (2*ptrsz) } len3, ok3 := i.module.Memory().ReadUint32Le(uint32(results0 + 8)) if !ok3 { var default3 string return default3, errors.New("failed to read length from memory") } + // StringLift buf4, ok4 := i.module.Memory().Read(ptr2, len3) if !ok4 { var default4 string @@ -217,16 +245,19 @@ func (i *BasicInstance) Hello( str4 := string(buf4) value8 = str4 case 1: + // PointerLoad { offset: ptrsz } ptr5, ok5 := i.module.Memory().ReadUint32Le(uint32(results0 + 4)) if !ok5 { var default5 string return default5, errors.New("failed to read pointer from memory") } + // LengthLoad { offset: (2*ptrsz) } len6, ok6 := i.module.Memory().ReadUint32Le(uint32(results0 + 8)) if !ok6 { var default6 string return default6, errors.New("failed to read length from memory") } + // StringLift buf7, ok7 := i.module.Memory().Read(ptr5, len6) if !ok7 { var default7 string @@ -237,12 +268,15 @@ func (i *BasicInstance) Hello( default: err8 = errors.New("invalid variant discriminant for expected") } + // Flush { amt: 1 } + // Return { amt: 1, func: Function { name: "hello", kind: Freestanding, params: [], result: Some(Id(Id { idx: 0 })), docs: Docs { contents: None }, stability: Unknown } } return value8, err8 } func (i *BasicInstance) Primitive( ctx context.Context, ) bool { + // CallWasm { name: "primitive", sig: WasmSignature { params: [], results: [I32], indirect_params: false, retptr: false } } raw0, err0 := i.module.ExportedFunction("primitive").Call(ctx, ) // The return type doesn't contain an error so we panic if one is encountered if err0 != nil { @@ -250,45 +284,81 @@ func (i *BasicInstance) Primitive( } results0 := raw0[0] + // BoolFromI32 value1 := results0 != 0 + // Return { amt: 1, func: Function { name: "primitive", kind: Freestanding, params: [], result: Some(Bool), docs: Docs { contents: None }, stability: Unknown } } return value1 } func (i *BasicInstance) OptionalPrimitive( ctx context.Context, + b bool, ) (bool, bool) { - raw0, err0 := i.module.ExportedFunction("optional-primitive").Call(ctx, ) + arg0 := b + // GetArg { nth: 0 } + // OptionLower { payload: Bool, ty: Id { idx: 1 }, results: [I32, I32] } + var variant1_0 uint32 + var variant1_1 uint32 + if reflect.ValueOf(arg0).IsZero() { + // VariantPayloadName + // I32Const { val: 0 } + // ConstZero { tys: [I32] } + variant1_0 = 0 + variant1_1 = 0 + } else { + variantPayload := arg0 + // VariantPayloadName + // I32Const { val: 1 } + // I32FromBool + var value0 uint32 + if variantPayload { + value0 = 1 + } else { + value0 = 0 + } + variant1_0 = 1 + variant1_1 = value0 + } + // CallWasm { name: "optional-primitive", sig: WasmSignature { params: [I32, I32], results: [Pointer], indirect_params: false, retptr: true } } + raw2, err2 := i.module.ExportedFunction("optional-primitive").Call(ctx, uint64(variant1_0), uint64(variant1_1)) // The return type doesn't contain an error so we panic if one is encountered - if err0 != nil { - panic(err0) + if err2 != nil { + panic(err2) } - results0 := raw0[0] - value1, ok1 := i.module.Memory().ReadByte(uint32(results0 + 0)) + results2 := raw2[0] + // I32Load8U { offset: 0 } + value3, ok3 := i.module.Memory().ReadByte(uint32(results2 + 0)) // The return type doesn't contain an error so we panic if one is encountered - if !ok1 { + if !ok3 { panic(errors.New("failed to read byte from memory")) } - var result4 bool - var ok4 bool - if value1 == 0 { - ok4 = false + // OptionLift { payload: Bool, ty: Id { idx: 1 } } + var result6 bool + var ok6 bool + if value3 == 0 { + ok6 = false } else { - value2, ok2 := i.module.Memory().ReadByte(uint32(results0 + 1)) + // I32Load8U { offset: 1 } + value4, ok4 := i.module.Memory().ReadByte(uint32(results2 + 1)) // The return type doesn't contain an error so we panic if one is encountered - if !ok2 { + if !ok4 { panic(errors.New("failed to read byte from memory")) } - value3 := value2 != 0 - ok4 = true - result4 = value3 + // BoolFromI32 + value5 := value4 != 0 + ok6 = true + result6 = value5 } - return result4, ok4 + // Flush { amt: 1 } + // Return { amt: 1, func: Function { name: "optional-primitive", kind: Freestanding, params: [("b", Id(Id { idx: 1 }))], result: Some(Id(Id { idx: 1 })), docs: Docs { contents: None }, stability: Unknown } } + return result6, ok6 } func (i *BasicInstance) ResultPrimitive( ctx context.Context, ) (bool, error) { + // CallWasm { name: "result-primitive", sig: WasmSignature { params: [], results: [Pointer], indirect_params: false, retptr: true } } raw0, err0 := i.module.ExportedFunction("result-primitive").Call(ctx, ) if err0 != nil { var default0 bool @@ -308,33 +378,40 @@ func (i *BasicInstance) ResultPrimitive( }() results0 := raw0[0] + // I32Load8U { offset: 0 } value1, ok1 := i.module.Memory().ReadByte(uint32(results0 + 0)) if !ok1 { var default1 bool return default1, errors.New("failed to read byte from memory") } + // ResultLift { result: Result_ { ok: Some(Bool), err: Some(String) }, ty: Id { idx: 2 } } var value7 bool var err7 error switch value1 { case 0: + // I32Load8U { offset: ptrsz } value2, ok2 := i.module.Memory().ReadByte(uint32(results0 + 4)) if !ok2 { var default2 bool return default2, errors.New("failed to read byte from memory") } + // BoolFromI32 value3 := value2 != 0 value7 = value3 case 1: + // PointerLoad { offset: ptrsz } ptr4, ok4 := i.module.Memory().ReadUint32Le(uint32(results0 + 4)) if !ok4 { var default4 bool return default4, errors.New("failed to read pointer from memory") } + // LengthLoad { offset: (2*ptrsz) } len5, ok5 := i.module.Memory().ReadUint32Le(uint32(results0 + 8)) if !ok5 { var default5 bool return default5, errors.New("failed to read length from memory") } + // StringLift buf6, ok6 := i.module.Memory().Read(ptr4, len5) if !ok6 { var default6 bool @@ -345,6 +422,100 @@ func (i *BasicInstance) ResultPrimitive( default: err7 = errors.New("invalid variant discriminant for expected") } + // Flush { amt: 1 } + // Return { amt: 1, func: Function { name: "result-primitive", kind: Freestanding, params: [], result: Some(Id(Id { idx: 2 })), docs: Docs { contents: None }, stability: Unknown } } return value7, err7 } +func (i *BasicInstance) OptionalString( + ctx context.Context, + s string, +) (string, bool) { + arg0 := s + // GetArg { nth: 0 } + // OptionLower { payload: String, ty: Id { idx: 3 }, results: [I32, Pointer, Length] } + var variant1_0 uint32 + var variant1_1 uint64 + var variant1_2 uint64 + if reflect.ValueOf(arg0).IsZero() { + // VariantPayloadName + // I32Const { val: 0 } + // ConstZero { tys: [Pointer, Length] } + variant1_0 = 0 + variant1_1 = 0 + variant1_2 = 0 + } else { + variantPayload := arg0 + // VariantPayloadName + // I32Const { val: 1 } + // StringLower { realloc: Some("cabi_realloc") } + memory0 := i.module.Memory() + realloc0 := i.module.ExportedFunction("cabi_realloc") + ptr0, len0, err0 := writeString(ctx, variantPayload, memory0, realloc0) + // The return type doesn't contain an error so we panic if one is encountered + if err0 != nil { + panic(err0) + } + variant1_0 = 1 + variant1_1 = ptr0 + variant1_2 = len0 + } + // CallWasm { name: "optional-string", sig: WasmSignature { params: [I32, Pointer, Length], results: [Pointer], indirect_params: false, retptr: true } } + raw2, err2 := i.module.ExportedFunction("optional-string").Call(ctx, uint64(variant1_0), uint64(variant1_1), uint64(variant1_2)) + // The return type doesn't contain an error so we panic if one is encountered + if err2 != nil { + panic(err2) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if _, err := i.module.ExportedFunction("cabi_post_optional-string").Call(ctx, raw2...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + }() + + results2 := raw2[0] + // I32Load8U { offset: 0 } + value3, ok3 := i.module.Memory().ReadByte(uint32(results2 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok3 { + panic(errors.New("failed to read byte from memory")) + } + // OptionLift { payload: String, ty: Id { idx: 3 } } + var result7 string + var ok7 bool + if value3 == 0 { + ok7 = false + } else { + // PointerLoad { offset: ptrsz } + ptr4, ok4 := i.module.Memory().ReadUint32Le(uint32(results2 + 4)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok4 { + panic(errors.New("failed to read pointer from memory")) + } + // LengthLoad { offset: (2*ptrsz) } + len5, ok5 := i.module.Memory().ReadUint32Le(uint32(results2 + 8)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok5 { + panic(errors.New("failed to read length from memory")) + } + // StringLift + buf6, ok6 := i.module.Memory().Read(ptr4, len5) + // The return type doesn't contain an error so we panic if one is encountered + if !ok6 { + panic(errors.New("failed to read bytes from memory")) + } + str6 := string(buf6) + ok7 = true + result7 = str6 + } + // Flush { amt: 1 } + // Return { amt: 1, func: Function { name: "optional-string", kind: Freestanding, params: [("s", Id(Id { idx: 3 }))], result: Some(Id(Id { idx: 3 })), docs: Docs { contents: None }, stability: Unknown } } + return result7, ok7 +} + diff --git a/cmd/gravity/tests/cmd/iface-method-returns-string.stdout b/cmd/gravity/tests/cmd/iface-method-returns-string.stdout index 83f4cb0..c8a72ab 100644 --- a/cmd/gravity/tests/cmd/iface-method-returns-string.stdout +++ b/cmd/gravity/tests/cmd/iface-method-returns-string.stdout @@ -36,6 +36,7 @@ func NewExampleFactory( ) (*ExampleFactory, error) { wazeroRuntime := wazero.NewRuntime(ctx) + // Instantiate import host modules _, err0 := wazeroRuntime.NewHostModuleBuilder("arcjet:example/runtime"). NewFunctionBuilder(). WithFunc(func( @@ -43,15 +44,21 @@ func NewExampleFactory( mod api.Module, arg0 uint32, ) { + // CallInterface { func: Function { name: "os", kind: Freestanding, params: [], result: Some(String), docs: Docs { contents: None }, stability: Unknown }, async_: false } value0 := runtime.Os(ctx, ) + // GetArg { nth: 0 } + // StringLower { realloc: Some("cabi_realloc") } memory1 := mod.Memory() realloc1 := mod.ExportedFunction("cabi_realloc") ptr1, len1, err1 := writeString(ctx, value0, memory1, realloc1) if err1 != nil { panic(err1) } - mod.Memory().WriteUint32Le(arg0+4, uint32(len1)) - mod.Memory().WriteUint32Le(arg0+0, uint32(ptr1)) + // LengthStore { offset: ptrsz } + mod.Memory().WriteUint32Le(uint32(arg0+4), uint32(len1)) + // PointerStore { offset: 0 } + mod.Memory().WriteUint32Le(uint32(arg0+0), uint32(ptr1)) + // Return { amt: 0, func: Function { name: "os", kind: Freestanding, params: [], result: Some(String), docs: Docs { contents: None }, stability: Unknown } } }). Export("os"). NewFunctionBuilder(). @@ -60,15 +67,21 @@ func NewExampleFactory( mod api.Module, arg0 uint32, ) { + // CallInterface { func: Function { name: "arch", kind: Freestanding, params: [], result: Some(String), docs: Docs { contents: None }, stability: Unknown }, async_: false } value0 := runtime.Arch(ctx, ) + // GetArg { nth: 0 } + // StringLower { realloc: Some("cabi_realloc") } memory1 := mod.Memory() realloc1 := mod.ExportedFunction("cabi_realloc") ptr1, len1, err1 := writeString(ctx, value0, memory1, realloc1) if err1 != nil { panic(err1) } - mod.Memory().WriteUint32Le(arg0+4, uint32(len1)) - mod.Memory().WriteUint32Le(arg0+0, uint32(ptr1)) + // LengthStore { offset: ptrsz } + mod.Memory().WriteUint32Le(uint32(arg0+4), uint32(len1)) + // PointerStore { offset: 0 } + mod.Memory().WriteUint32Le(uint32(arg0+0), uint32(ptr1)) + // Return { amt: 0, func: Function { name: "arch", kind: Freestanding, params: [], result: Some(String), docs: Docs { contents: None }, stability: Unknown } } }). Export("arch"). NewFunctionBuilder(). @@ -78,12 +91,17 @@ func NewExampleFactory( arg0 uint32, arg1 uint32, ) { + // GetArg { nth: 0 } + // GetArg { nth: 1 } + // StringLift buf0, ok0 := mod.Memory().Read(arg0, arg1) if !ok0 { panic(errors.New("failed to read bytes from memory")) } str0 := string(buf0) + // CallInterface { func: Function { name: "puts", kind: Freestanding, params: [("msg", String)], result: None, docs: Docs { contents: None }, stability: Unknown }, async_: false } runtime.Puts(ctx, str0) + // Return { amt: 0, func: Function { name: "puts", kind: Freestanding, params: [("msg", String)], result: None, docs: Docs { contents: None }, stability: Unknown } } }). Export("puts"). Instantiate(ctx) @@ -154,6 +172,7 @@ func writeString( func (i *ExampleInstance) Hello( ctx context.Context, ) (string, error) { + // CallWasm { name: "hello", sig: WasmSignature { params: [], results: [Pointer], indirect_params: false, retptr: true } } raw0, err0 := i.module.ExportedFunction("hello").Call(ctx, ) if err0 != nil { var default0 string @@ -173,25 +192,30 @@ func (i *ExampleInstance) Hello( }() results0 := raw0[0] + // I32Load8U { offset: 0 } value1, ok1 := i.module.Memory().ReadByte(uint32(results0 + 0)) if !ok1 { var default1 string return default1, errors.New("failed to read byte from memory") } + // ResultLift { result: Result_ { ok: Some(String), err: Some(String) }, ty: Id { idx: 0 } } var value8 string var err8 error switch value1 { case 0: + // PointerLoad { offset: ptrsz } ptr2, ok2 := i.module.Memory().ReadUint32Le(uint32(results0 + 4)) if !ok2 { var default2 string return default2, errors.New("failed to read pointer from memory") } + // LengthLoad { offset: (2*ptrsz) } len3, ok3 := i.module.Memory().ReadUint32Le(uint32(results0 + 8)) if !ok3 { var default3 string return default3, errors.New("failed to read length from memory") } + // StringLift buf4, ok4 := i.module.Memory().Read(ptr2, len3) if !ok4 { var default4 string @@ -200,16 +224,19 @@ func (i *ExampleInstance) Hello( str4 := string(buf4) value8 = str4 case 1: + // PointerLoad { offset: ptrsz } ptr5, ok5 := i.module.Memory().ReadUint32Le(uint32(results0 + 4)) if !ok5 { var default5 string return default5, errors.New("failed to read pointer from memory") } + // LengthLoad { offset: (2*ptrsz) } len6, ok6 := i.module.Memory().ReadUint32Le(uint32(results0 + 8)) if !ok6 { var default6 string return default6, errors.New("failed to read length from memory") } + // StringLift buf7, ok7 := i.module.Memory().Read(ptr5, len6) if !ok7 { var default7 string @@ -220,6 +247,8 @@ func (i *ExampleInstance) Hello( default: err8 = errors.New("invalid variant discriminant for expected") } + // Flush { amt: 1 } + // Return { amt: 1, func: Function { name: "hello", kind: Freestanding, params: [], result: Some(Id(Id { idx: 0 })), docs: Docs { contents: None }, stability: Unknown } } return value8, err8 } diff --git a/cmd/gravity/tests/cmd/instructions.stdout b/cmd/gravity/tests/cmd/instructions.stdout index a20c965..7c7e41d 100644 --- a/cmd/gravity/tests/cmd/instructions.stdout +++ b/cmd/gravity/tests/cmd/instructions.stdout @@ -22,6 +22,8 @@ func NewInstructionsFactory( ) (*InstructionsFactory, error) { wazeroRuntime := wazero.NewRuntime(ctx) + // Instantiate import host modules + // Compiling the module takes a LONG time, so we want to do it once and hold // onto it with the Runtime module, err := wazeroRuntime.CompileModule(ctx, wasmFileInstructions) @@ -87,7 +89,10 @@ func (i *InstructionsInstance) S8Roundtrip( val int8, ) int8 { arg0 := val + // GetArg { nth: 0 } + // I32FromS8 value0 := api.EncodeI32(int32(arg0)) + // CallWasm { name: "s8-roundtrip", sig: WasmSignature { params: [I32], results: [I32], indirect_params: false, retptr: false } } raw1, err1 := i.module.ExportedFunction("s8-roundtrip").Call(ctx, uint64(value0)) // The return type doesn't contain an error so we panic if one is encountered if err1 != nil { @@ -95,7 +100,9 @@ func (i *InstructionsInstance) S8Roundtrip( } results1 := raw1[0] + // S8FromI32 result2 := int8(api.DecodeI32(results1)) + // Return { amt: 1, func: Function { name: "s8-roundtrip", kind: Freestanding, params: [("val", S8)], result: Some(S8), docs: Docs { contents: None }, stability: Unknown } } return result2 } @@ -104,7 +111,10 @@ func (i *InstructionsInstance) U8Roundtrip( val uint8, ) uint8 { arg0 := val + // GetArg { nth: 0 } + // I32FromU8 value0 := api.EncodeI32(int32(arg0)) + // CallWasm { name: "u8-roundtrip", sig: WasmSignature { params: [I32], results: [I32], indirect_params: false, retptr: false } } raw1, err1 := i.module.ExportedFunction("u8-roundtrip").Call(ctx, uint64(value0)) // The return type doesn't contain an error so we panic if one is encountered if err1 != nil { @@ -112,7 +122,9 @@ func (i *InstructionsInstance) U8Roundtrip( } results1 := raw1[0] + // U8FromI32 result2 := uint8(api.DecodeU32(results1)) + // Return { amt: 1, func: Function { name: "u8-roundtrip", kind: Freestanding, params: [("val", U8)], result: Some(U8), docs: Docs { contents: None }, stability: Unknown } } return result2 } @@ -121,7 +133,10 @@ func (i *InstructionsInstance) S16Roundtrip( val int16, ) int16 { arg0 := val + // GetArg { nth: 0 } + // I32FromS16 value0 := api.EncodeI32(int32(arg0)) + // CallWasm { name: "s16-roundtrip", sig: WasmSignature { params: [I32], results: [I32], indirect_params: false, retptr: false } } raw1, err1 := i.module.ExportedFunction("s16-roundtrip").Call(ctx, uint64(value0)) // The return type doesn't contain an error so we panic if one is encountered if err1 != nil { @@ -129,7 +144,9 @@ func (i *InstructionsInstance) S16Roundtrip( } results1 := raw1[0] + // S16FromI32 result2 := int16(api.DecodeI32(results1)) + // Return { amt: 1, func: Function { name: "s16-roundtrip", kind: Freestanding, params: [("val", S16)], result: Some(S16), docs: Docs { contents: None }, stability: Unknown } } return result2 } @@ -138,7 +155,10 @@ func (i *InstructionsInstance) U16Roundtrip( val uint16, ) uint16 { arg0 := val + // GetArg { nth: 0 } + // I32FromU16 value0 := api.EncodeI32(int32(arg0)) + // CallWasm { name: "u16-roundtrip", sig: WasmSignature { params: [I32], results: [I32], indirect_params: false, retptr: false } } raw1, err1 := i.module.ExportedFunction("u16-roundtrip").Call(ctx, uint64(value0)) // The return type doesn't contain an error so we panic if one is encountered if err1 != nil { @@ -146,7 +166,9 @@ func (i *InstructionsInstance) U16Roundtrip( } results1 := raw1[0] + // U16FromI32 result2 := uint16(api.DecodeU32(results1)) + // Return { amt: 1, func: Function { name: "u16-roundtrip", kind: Freestanding, params: [("val", U16)], result: Some(U16), docs: Docs { contents: None }, stability: Unknown } } return result2 } @@ -155,7 +177,10 @@ func (i *InstructionsInstance) S32Roundtrip( val int32, ) int32 { arg0 := val + // GetArg { nth: 0 } + // I32FromS32 value0 := api.EncodeI32(arg0) + // CallWasm { name: "s32-roundtrip", sig: WasmSignature { params: [I32], results: [I32], indirect_params: false, retptr: false } } raw1, err1 := i.module.ExportedFunction("s32-roundtrip").Call(ctx, uint64(value0)) // The return type doesn't contain an error so we panic if one is encountered if err1 != nil { @@ -163,7 +188,9 @@ func (i *InstructionsInstance) S32Roundtrip( } results1 := raw1[0] + // S32FromI32 result2 := api.DecodeI32(results1) + // Return { amt: 1, func: Function { name: "s32-roundtrip", kind: Freestanding, params: [("val", S32)], result: Some(S32), docs: Docs { contents: None }, stability: Unknown } } return result2 } @@ -172,7 +199,10 @@ func (i *InstructionsInstance) U32Roundtrip( val uint32, ) uint32 { arg0 := val + // GetArg { nth: 0 } + // I32FromU32 result0 := api.EncodeU32(arg0) + // CallWasm { name: "u32-roundtrip", sig: WasmSignature { params: [I32], results: [I32], indirect_params: false, retptr: false } } raw1, err1 := i.module.ExportedFunction("u32-roundtrip").Call(ctx, uint64(result0)) // The return type doesn't contain an error so we panic if one is encountered if err1 != nil { @@ -180,7 +210,53 @@ func (i *InstructionsInstance) U32Roundtrip( } results1 := raw1[0] - result2 := api.DecodeU32(results1) + // U32FromI32 + result2 := api.DecodeU32(uint64(results1)) + // Return { amt: 1, func: Function { name: "u32-roundtrip", kind: Freestanding, params: [("val", U32)], result: Some(U32), docs: Docs { contents: None }, stability: Unknown } } + return result2 +} + +func (i *InstructionsInstance) F32Roundtrip( + ctx context.Context, + val float32, +) float32 { + arg0 := val + // GetArg { nth: 0 } + // CoreF32FromF32 + result0 := api.EncodeF32(float32(arg0)) + // CallWasm { name: "f32-roundtrip", sig: WasmSignature { params: [F32], results: [F32], indirect_params: false, retptr: false } } + raw1, err1 := i.module.ExportedFunction("f32-roundtrip").Call(ctx, uint64(result0)) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + + results1 := raw1[0] + // F32FromCoreF32 + result2 := api.DecodeF32(results1) + // Return { amt: 1, func: Function { name: "f32-roundtrip", kind: Freestanding, params: [("val", F32)], result: Some(F32), docs: Docs { contents: None }, stability: Unknown } } + return result2 +} + +func (i *InstructionsInstance) F64Roundtrip( + ctx context.Context, + val float64, +) float64 { + arg0 := val + // GetArg { nth: 0 } + // CoreF64FromF64 + result0 := api.EncodeF64(float64(arg0)) + // CallWasm { name: "f64-roundtrip", sig: WasmSignature { params: [F64], results: [F64], indirect_params: false, retptr: false } } + raw1, err1 := i.module.ExportedFunction("f64-roundtrip").Call(ctx, uint64(result0)) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + + results1 := raw1[0] + // F64FromCoreF64 + result2 := api.DecodeF64(results1) + // Return { amt: 1, func: Function { name: "f64-roundtrip", kind: Freestanding, params: [("val", F64)], result: Some(F64), docs: Docs { contents: None }, stability: Unknown } } return result2 } diff --git a/cmd/gravity/tests/cmd/outlier.stderr b/cmd/gravity/tests/cmd/outlier.stderr new file mode 100644 index 0000000..e69de29 diff --git a/cmd/gravity/tests/cmd/outlier.stdout b/cmd/gravity/tests/cmd/outlier.stdout new file mode 100644 index 0000000..99f8d95 --- /dev/null +++ b/cmd/gravity/tests/cmd/outlier.stdout @@ -0,0 +1,591 @@ +// Code generated by arcjet-gravity; DO NOT EDIT. + +package outlier + +import "context" +import "errors" +import "github.com/tetratelabs/wazero" +import "github.com/tetratelabs/wazero/api" + +import _ "embed" + +//go:embed outlier.wasm +var wasmFileOutlier []byte + +type IOutlierTypes interface {} + +type OutlierInterval struct { + Start uint32 + + End *uint32 +} + +type Series struct { + IsOutlier bool + + OutlierIntervals []OutlierInterval + + Scores []float64 +} + +type Band struct { + Min []float64 + + Max []float64 +} + +type Output struct { + OutlyingSeries []uint32 + + SeriesResults []Series + + ClusterBand *Band +} + +type Error = string + +type EpsilonOrSensitivity interface { + isEpsilonOrSensitivity() +} + +type EpsilonOrSensitivitySensitivity float64 +func (EpsilonOrSensitivitySensitivity) isEpsilonOrSensitivity() {} + +type EpsilonOrSensitivityEpsilon float64 +func (EpsilonOrSensitivityEpsilon) isEpsilonOrSensitivity() {} + +type DbscanParams struct { + EpsilonOrSensitivity EpsilonOrSensitivity +} + +type ThresholdOrSensitivity interface { + isThresholdOrSensitivity() +} + +type ThresholdOrSensitivitySensitivity float64 +func (ThresholdOrSensitivitySensitivity) isThresholdOrSensitivity() {} + +type ThresholdOrSensitivityThreshold float64 +func (ThresholdOrSensitivityThreshold) isThresholdOrSensitivity() {} + +type MadParams struct { + ThresholdOrSensitivity ThresholdOrSensitivity +} + +type Algorithm interface { + isAlgorithm() +} + +type AlgorithmDbscan DbscanParams +func (AlgorithmDbscan) isAlgorithm() {} + +type AlgorithmMad MadParams +func (AlgorithmMad) isAlgorithm() {} + +type Input struct { + Data [][]float64 + + Algorithm Algorithm +} + +type OutlierFactory struct { + runtime wazero.Runtime + module wazero.CompiledModule +} + +func NewOutlierFactory( + ctx context.Context, + types IOutlierTypes, +) (*OutlierFactory, error) { + wazeroRuntime := wazero.NewRuntime(ctx) + + // Instantiate import host modules + _, err0 := wazeroRuntime.NewHostModuleBuilder("augurs:outlier/types"). + Instantiate(ctx) + if err0 != nil { + return nil, err0 + } + + // Compiling the module takes a LONG time, so we want to do it once and hold + // onto it with the Runtime + module, err := wazeroRuntime.CompileModule(ctx, wasmFileOutlier) + if err != nil { + return nil, err + } + return &OutlierFactory{ + runtime: wazeroRuntime, + module: module, + }, nil +} + +func (f *OutlierFactory) Instantiate(ctx context.Context) (*OutlierInstance, error) { + if module, err := f.runtime.InstantiateModule(ctx, f.module, wazero.NewModuleConfig()); err != nil { + return nil, err + } else { + return &OutlierInstance{module}, nil + } +} + +func (f *OutlierFactory) Close(ctx context.Context) { + f.runtime.Close(ctx) +} + +type OutlierInstance struct { + module api.Module +} + +func (i *OutlierInstance) Close(ctx context.Context) error { + if err := i.module.Close(ctx); err != nil { + return err + } + + return nil +} + +// writeString will put a Go string into the Wasm memory following the Component +// Model calling convetions, such as allocating memory with the realloc function +func writeString( + ctx context.Context, + s string, + memory api.Memory, + realloc api.Function, +) (uint64, uint64, error) { + if len(s) == 0 { + return 1, 0, nil + } + + results, err := realloc.Call(ctx, 0, 0, 1, uint64(len(s))) + if err != nil { + return 1, 0, err + } + ptr := results[0] + ok := memory.Write(uint32(ptr), []byte(s)) + if !ok { + return 1, 0, err + } + return uint64(ptr), uint64(len(s)), nil +} + +func (i *OutlierInstance) Detect( + ctx context.Context, + input Input, +) (Output, error) { + arg0 := input + // GetArg { nth: 0 } + // RecordLower { record: Record { fields: [Field { name: "data", ty: Id(Id { idx: 16 }), docs: Docs { contents: None } }, Field { name: "algorithm", ty: Id(Id { idx: 15 }), docs: Docs { contents: None } }] }, name: "input", ty: Id { idx: 17 } } + data0 := arg0.Data + algorithm0 := arg0.Algorithm + // ListLower { element: Id(Id { idx: 3 }), realloc: Some("cabi_realloc") } + vec3 := data0 + len3 := uint64(len(vec3)) + result3, err3 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len3 * 8) + if err3 != nil { + var default3 Output + return default3, err3 + } + ptr3 := result3[0] + for idx := uint64(0); idx < len3; idx++ { + e := vec3[idx] + base := uint32(ptr3 + uint64(idx) * uint64(8)) + // IterElem { element: Id(Id { idx: 3 }) } + // IterBasePointer + // ListLower { element: F64, realloc: Some("cabi_realloc") } + vec2 := e + len2 := uint64(len(vec2)) + result2, err2 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 8, len2 * 8) + if err2 != nil { + var default2 Output + return default2, err2 + } + ptr2 := result2[0] + for idx := uint64(0); idx < len2; idx++ { + e := vec2[idx] + base := uint32(ptr2 + uint64(idx) * uint64(8)) + // IterElem { element: F64 } + // IterBasePointer + // CoreF64FromF64 + result1 := api.EncodeF64(float64(e)) + // F64Store { offset: 0 } + i.module.Memory().WriteUint64Le(uint32(base+0), result1) + } + // LengthStore { offset: ptrsz } + i.module.Memory().WriteUint32Le(uint32(base+4), uint32(len2)) + // PointerStore { offset: 0 } + i.module.Memory().WriteUint32Le(uint32(base+0), uint32(ptr2)) + } + // VariantLower { variant: Variant { cases: [Case { name: "dbscan", ty: Some(Id(Id { idx: 12 })), docs: Docs { contents: None } }, Case { name: "mad", ty: Some(Id(Id { idx: 14 })), docs: Docs { contents: None } }] }, name: "algorithm", ty: Id { idx: 15 }, results: [I32, I32, F64] } + var variant12_0 uint64 + var variant12_1 uint64 + var variant12_2 uint64 + switch variantPayload := algorithm0.(type) { + case AlgorithmDbscan: + // VariantPayloadName + // I32Const { val: 0 } + // RecordLower { record: Record { fields: [Field { name: "epsilon-or-sensitivity", ty: Id(Id { idx: 11 }), docs: Docs { contents: None } }] }, name: "dbscan-params", ty: Id { idx: 12 } } + epsilonOrSensitivity4 := variantPayload.EpsilonOrSensitivity + // VariantLower { variant: Variant { cases: [Case { name: "sensitivity", ty: Some(F64), docs: Docs { contents: None } }, Case { name: "epsilon", ty: Some(F64), docs: Docs { contents: None } }] }, name: "epsilon-or-sensitivity", ty: Id { idx: 11 }, results: [I32, F64] } + var variant7_0 uint64 + var variant7_1 uint64 + switch variantPayload := epsilonOrSensitivity4.(type) { + case EpsilonOrSensitivitySensitivity: + // VariantPayloadName + // I32Const { val: 0 } + // CoreF64FromF64 + result5 := api.EncodeF64(float64(variantPayload)) + variant7_0 = 0 + variant7_1 = result5 + case EpsilonOrSensitivityEpsilon: + // VariantPayloadName + // I32Const { val: 1 } + // CoreF64FromF64 + result6 := api.EncodeF64(float64(variantPayload)) + variant7_0 = 1 + variant7_1 = result6 + default: + var default7 Output + return default7, errors.New("invalid variant type provided") + } + variant12_0 = 0 + variant12_1 = variant7_0 + variant12_2 = variant7_1 + case AlgorithmMad: + // VariantPayloadName + // I32Const { val: 1 } + // RecordLower { record: Record { fields: [Field { name: "threshold-or-sensitivity", ty: Id(Id { idx: 13 }), docs: Docs { contents: None } }] }, name: "mad-params", ty: Id { idx: 14 } } + thresholdOrSensitivity8 := variantPayload.ThresholdOrSensitivity + // VariantLower { variant: Variant { cases: [Case { name: "sensitivity", ty: Some(F64), docs: Docs { contents: None } }, Case { name: "threshold", ty: Some(F64), docs: Docs { contents: None } }] }, name: "threshold-or-sensitivity", ty: Id { idx: 13 }, results: [I32, F64] } + var variant11_0 uint64 + var variant11_1 uint64 + switch variantPayload := thresholdOrSensitivity8.(type) { + case ThresholdOrSensitivitySensitivity: + // VariantPayloadName + // I32Const { val: 0 } + // CoreF64FromF64 + result9 := api.EncodeF64(float64(variantPayload)) + variant11_0 = 0 + variant11_1 = result9 + case ThresholdOrSensitivityThreshold: + // VariantPayloadName + // I32Const { val: 1 } + // CoreF64FromF64 + result10 := api.EncodeF64(float64(variantPayload)) + variant11_0 = 1 + variant11_1 = result10 + default: + var default11 Output + return default11, errors.New("invalid variant type provided") + } + variant12_0 = 1 + variant12_1 = variant11_0 + variant12_2 = variant11_1 + default: + var default12 Output + return default12, errors.New("invalid variant type provided") + } + // CallWasm { name: "detect", sig: WasmSignature { params: [Pointer, Length, I32, I32, F64], results: [Pointer], indirect_params: false, retptr: true } } + raw13, err13 := i.module.ExportedFunction("detect").Call(ctx, uint64(ptr3), uint64(len3), uint64(variant12_0), uint64(variant12_1), uint64(variant12_2)) + if err13 != nil { + var default13 Output + return default13, err13 + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if _, err := i.module.ExportedFunction("cabi_post_detect").Call(ctx, raw13...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + }() + + results13 := raw13[0] + // I32Load8U { offset: 0 } + value14, ok14 := i.module.Memory().ReadByte(uint32(results13 + 0)) + if !ok14 { + var default14 Output + return default14, errors.New("failed to read byte from memory") + } + // ResultLift { result: Result_ { ok: Some(Id(Id { idx: 19 })), err: Some(String) }, ty: Id { idx: 20 } } + var value58 Output + var err58 error + switch value14 { + case 0: + // PointerLoad { offset: ptrsz } + ptr15, ok15 := i.module.Memory().ReadUint32Le(uint32(results13 + 4)) + if !ok15 { + var default15 Output + return default15, errors.New("failed to read pointer from memory") + } + // LengthLoad { offset: (2*ptrsz) } + len16, ok16 := i.module.Memory().ReadUint32Le(uint32(results13 + 8)) + if !ok16 { + var default16 Output + return default16, errors.New("failed to read length from memory") + } + // ListLift { element: U32, ty: Id { idx: 6 } } + base19 := ptr15 + len19 := len16 + result19 := make([]uint32, len19) + for idx19 := uint32(0); idx19 < len19; idx19++ { + base := base19 + idx19 * 4 + // IterBasePointer + // I32Load { offset: 0 } + value17, ok17 := i.module.Memory().ReadUint32Le(uint32(base + 0)) + if !ok17 { + var default17 Output + return default17, errors.New("failed to read i32 from memory") + } + // U32FromI32 + result18 := api.DecodeU32(uint64(value17)) + result19[idx19] = result18 + } + // PointerLoad { offset: (3*ptrsz) } + ptr20, ok20 := i.module.Memory().ReadUint32Le(uint32(results13 + 12)) + if !ok20 { + var default20 Output + return default20, errors.New("failed to read pointer from memory") + } + // LengthLoad { offset: (4*ptrsz) } + len21, ok21 := i.module.Memory().ReadUint32Le(uint32(results13 + 16)) + if !ok21 { + var default21 Output + return default21, errors.New("failed to read length from memory") + } + // ListLift { element: Id(Id { idx: 4 }), ty: Id { idx: 7 } } + base40 := ptr20 + len40 := len21 + result40 := make([]Series, len40) + for idx40 := uint32(0); idx40 < len40; idx40++ { + base := base40 + idx40 * 20 + // IterBasePointer + // I32Load8U { offset: 0 } + value22, ok22 := i.module.Memory().ReadByte(uint32(base + 0)) + if !ok22 { + var default22 Output + return default22, errors.New("failed to read byte from memory") + } + // BoolFromI32 + value23 := value22 != 0 + // PointerLoad { offset: ptrsz } + ptr24, ok24 := i.module.Memory().ReadUint32Le(uint32(base + 4)) + if !ok24 { + var default24 Output + return default24, errors.New("failed to read pointer from memory") + } + // LengthLoad { offset: (2*ptrsz) } + len25, ok25 := i.module.Memory().ReadUint32Le(uint32(base + 8)) + if !ok25 { + var default25 Output + return default25, errors.New("failed to read length from memory") + } + // ListLift { element: Id(Id { idx: 1 }), ty: Id { idx: 2 } } + base33 := ptr24 + len33 := len25 + result33 := make([]OutlierInterval, len33) + for idx33 := uint32(0); idx33 < len33; idx33++ { + base := base33 + idx33 * 12 + // IterBasePointer + // I32Load { offset: 0 } + value26, ok26 := i.module.Memory().ReadUint32Le(uint32(base + 0)) + if !ok26 { + var default26 Output + return default26, errors.New("failed to read i32 from memory") + } + // U32FromI32 + result27 := api.DecodeU32(uint64(value26)) + // I32Load8U { offset: 4 } + value28, ok28 := i.module.Memory().ReadByte(uint32(base + 4)) + if !ok28 { + var default28 Output + return default28, errors.New("failed to read byte from memory") + } + // OptionLift { payload: U32, ty: Id { idx: 0 } } + var result31 uint32 + var ok31 bool + if value28 == 0 { + ok31 = false + } else { + // I32Load { offset: 8 } + value29, ok29 := i.module.Memory().ReadUint32Le(uint32(base + 8)) + if !ok29 { + var default29 Output + return default29, errors.New("failed to read i32 from memory") + } + // U32FromI32 + result30 := api.DecodeU32(uint64(value29)) + ok31 = true + result31 = result30 + } + // RecordLift { record: Record { fields: [Field { name: "start", ty: U32, docs: Docs { contents: None } }, Field { name: "end", ty: Id(Id { idx: 0 }), docs: Docs { contents: None } }] }, name: "outlier-interval", ty: Id { idx: 1 } } + var ptr32x1 *uint32 + if ok31 { + ptr32x1 = &result31 + } else { + ptr32x1 = nil + } + value32 := OutlierInterval{ + Start: result27, + End: ptr32x1, + } + result33[idx33] = value32 + } + // PointerLoad { offset: (3*ptrsz) } + ptr34, ok34 := i.module.Memory().ReadUint32Le(uint32(base + 12)) + if !ok34 { + var default34 Output + return default34, errors.New("failed to read pointer from memory") + } + // LengthLoad { offset: (4*ptrsz) } + len35, ok35 := i.module.Memory().ReadUint32Le(uint32(base + 16)) + if !ok35 { + var default35 Output + return default35, errors.New("failed to read length from memory") + } + // ListLift { element: F64, ty: Id { idx: 3 } } + base38 := ptr34 + len38 := len35 + result38 := make([]float64, len38) + for idx38 := uint32(0); idx38 < len38; idx38++ { + base := base38 + idx38 * 8 + // IterBasePointer + // F64Load { offset: 0 } + value36, ok36 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + if !ok36 { + var default36 Output + return default36, errors.New("failed to read f64 from memory") + } + // F64FromCoreF64 + result37 := api.DecodeF64(value36) + result38[idx38] = result37 + } + // RecordLift { record: Record { fields: [Field { name: "is-outlier", ty: Bool, docs: Docs { contents: None } }, Field { name: "outlier-intervals", ty: Id(Id { idx: 2 }), docs: Docs { contents: None } }, Field { name: "scores", ty: Id(Id { idx: 3 }), docs: Docs { contents: None } }] }, name: "series", ty: Id { idx: 4 } } + value39 := Series{ + IsOutlier: value23, + OutlierIntervals: result33, + Scores: result38, + } + result40[idx40] = value39 + } + // I32Load8U { offset: (5*ptrsz) } + value41, ok41 := i.module.Memory().ReadByte(uint32(results13 + 20)) + if !ok41 { + var default41 Output + return default41, errors.New("failed to read byte from memory") + } + // OptionLift { payload: Id(Id { idx: 5 }), ty: Id { idx: 8 } } + var result53 Band + var ok53 bool + if value41 == 0 { + ok53 = false + } else { + // PointerLoad { offset: (6*ptrsz) } + ptr42, ok42 := i.module.Memory().ReadUint32Le(uint32(results13 + 24)) + if !ok42 { + var default42 Output + return default42, errors.New("failed to read pointer from memory") + } + // LengthLoad { offset: (7*ptrsz) } + len43, ok43 := i.module.Memory().ReadUint32Le(uint32(results13 + 28)) + if !ok43 { + var default43 Output + return default43, errors.New("failed to read length from memory") + } + // ListLift { element: F64, ty: Id { idx: 3 } } + base46 := ptr42 + len46 := len43 + result46 := make([]float64, len46) + for idx46 := uint32(0); idx46 < len46; idx46++ { + base := base46 + idx46 * 8 + // IterBasePointer + // F64Load { offset: 0 } + value44, ok44 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + if !ok44 { + var default44 Output + return default44, errors.New("failed to read f64 from memory") + } + // F64FromCoreF64 + result45 := api.DecodeF64(value44) + result46[idx46] = result45 + } + // PointerLoad { offset: (8*ptrsz) } + ptr47, ok47 := i.module.Memory().ReadUint32Le(uint32(results13 + 32)) + if !ok47 { + var default47 Output + return default47, errors.New("failed to read pointer from memory") + } + // LengthLoad { offset: (9*ptrsz) } + len48, ok48 := i.module.Memory().ReadUint32Le(uint32(results13 + 36)) + if !ok48 { + var default48 Output + return default48, errors.New("failed to read length from memory") + } + // ListLift { element: F64, ty: Id { idx: 3 } } + base51 := ptr47 + len51 := len48 + result51 := make([]float64, len51) + for idx51 := uint32(0); idx51 < len51; idx51++ { + base := base51 + idx51 * 8 + // IterBasePointer + // F64Load { offset: 0 } + value49, ok49 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + if !ok49 { + var default49 Output + return default49, errors.New("failed to read f64 from memory") + } + // F64FromCoreF64 + result50 := api.DecodeF64(value49) + result51[idx51] = result50 + } + // RecordLift { record: Record { fields: [Field { name: "min", ty: Id(Id { idx: 3 }), docs: Docs { contents: None } }, Field { name: "max", ty: Id(Id { idx: 3 }), docs: Docs { contents: None } }] }, name: "band", ty: Id { idx: 5 } } + value52 := Band{ + Min: result46, + Max: result51, + } + ok53 = true + result53 = value52 + } + // RecordLift { record: Record { fields: [Field { name: "outlying-series", ty: Id(Id { idx: 6 }), docs: Docs { contents: None } }, Field { name: "series-results", ty: Id(Id { idx: 7 }), docs: Docs { contents: None } }, Field { name: "cluster-band", ty: Id(Id { idx: 8 }), docs: Docs { contents: None } }] }, name: "output", ty: Id { idx: 9 } } + var ptr54x2 *Band + if ok53 { + ptr54x2 = &result53 + } else { + ptr54x2 = nil + } + value54 := Output{ + OutlyingSeries: result19, + SeriesResults: result40, + ClusterBand: ptr54x2, + } + value58 = value54 + case 1: + // PointerLoad { offset: ptrsz } + ptr55, ok55 := i.module.Memory().ReadUint32Le(uint32(results13 + 4)) + if !ok55 { + var default55 Output + return default55, errors.New("failed to read pointer from memory") + } + // LengthLoad { offset: (2*ptrsz) } + len56, ok56 := i.module.Memory().ReadUint32Le(uint32(results13 + 8)) + if !ok56 { + var default56 Output + return default56, errors.New("failed to read length from memory") + } + // StringLift + buf57, ok57 := i.module.Memory().Read(ptr55, len56) + if !ok57 { + var default57 Output + return default57, errors.New("failed to read bytes from memory") + } + str57 := string(buf57) + err58 = errors.New(str57) + default: + err58 = errors.New("invalid variant discriminant for expected") + } + // Flush { amt: 1 } + // Return { amt: 1, func: Function { name: "detect", kind: Freestanding, params: [("input", Id(Id { idx: 18 }))], result: Some(Id(Id { idx: 20 })), docs: Docs { contents: None }, stability: Unknown } } + return value58, err58 +} + diff --git a/cmd/gravity/tests/cmd/outlier.toml b/cmd/gravity/tests/cmd/outlier.toml new file mode 100644 index 0000000..61b365e --- /dev/null +++ b/cmd/gravity/tests/cmd/outlier.toml @@ -0,0 +1,2 @@ +bin.name = "gravity" +args = "--world outlier ../../target/wasm32-unknown-unknown/release/example_outlier.wasm" diff --git a/cmd/gravity/tests/cmd/records.stderr b/cmd/gravity/tests/cmd/records.stderr new file mode 100644 index 0000000..e69de29 diff --git a/cmd/gravity/tests/cmd/records.stdout b/cmd/gravity/tests/cmd/records.stdout new file mode 100644 index 0000000..3a321e7 --- /dev/null +++ b/cmd/gravity/tests/cmd/records.stdout @@ -0,0 +1,322 @@ +// Code generated by arcjet-gravity; DO NOT EDIT. + +package records + +import "context" +import "errors" +import "github.com/tetratelabs/wazero" +import "github.com/tetratelabs/wazero/api" + +import _ "embed" + +//go:embed records.wasm +var wasmFileRecords []byte + +type IRecordsTypes interface {} + +type Foo struct { + Float32 float32 + + Float64 float64 + + Uint32 uint32 + + Uint64 uint64 + + S string + + Vf32 []float32 + + Vf64 []float64 +} + +type RecordsFactory struct { + runtime wazero.Runtime + module wazero.CompiledModule +} + +func NewRecordsFactory( + ctx context.Context, + types IRecordsTypes, +) (*RecordsFactory, error) { + wazeroRuntime := wazero.NewRuntime(ctx) + + // Instantiate import host modules + _, err0 := wazeroRuntime.NewHostModuleBuilder("arcjet:records/types"). + Instantiate(ctx) + if err0 != nil { + return nil, err0 + } + + // Compiling the module takes a LONG time, so we want to do it once and hold + // onto it with the Runtime + module, err := wazeroRuntime.CompileModule(ctx, wasmFileRecords) + if err != nil { + return nil, err + } + return &RecordsFactory{ + runtime: wazeroRuntime, + module: module, + }, nil +} + +func (f *RecordsFactory) Instantiate(ctx context.Context) (*RecordsInstance, error) { + if module, err := f.runtime.InstantiateModule(ctx, f.module, wazero.NewModuleConfig()); err != nil { + return nil, err + } else { + return &RecordsInstance{module}, nil + } +} + +func (f *RecordsFactory) Close(ctx context.Context) { + f.runtime.Close(ctx) +} + +type RecordsInstance struct { + module api.Module +} + +func (i *RecordsInstance) Close(ctx context.Context) error { + if err := i.module.Close(ctx); err != nil { + return err + } + + return nil +} + +// writeString will put a Go string into the Wasm memory following the Component +// Model calling convetions, such as allocating memory with the realloc function +func writeString( + ctx context.Context, + s string, + memory api.Memory, + realloc api.Function, +) (uint64, uint64, error) { + if len(s) == 0 { + return 1, 0, nil + } + + results, err := realloc.Call(ctx, 0, 0, 1, uint64(len(s))) + if err != nil { + return 1, 0, err + } + ptr := results[0] + ok := memory.Write(uint32(ptr), []byte(s)) + if !ok { + return 1, 0, err + } + return uint64(ptr), uint64(len(s)), nil +} + +func (i *RecordsInstance) ModifyFoo( + ctx context.Context, + f Foo, +) Foo { + arg0 := f + // GetArg { nth: 0 } + // RecordLower { record: Record { fields: [Field { name: "float32", ty: F32, docs: Docs { contents: None } }, Field { name: "float64", ty: F64, docs: Docs { contents: None } }, Field { name: "uint32", ty: U32, docs: Docs { contents: None } }, Field { name: "uint64", ty: U64, docs: Docs { contents: None } }, Field { name: "s", ty: String, docs: Docs { contents: None } }, Field { name: "vf32", ty: Id(Id { idx: 0 }), docs: Docs { contents: None } }, Field { name: "vf64", ty: Id(Id { idx: 1 }), docs: Docs { contents: None } }] }, name: "foo", ty: Id { idx: 2 } } + float320 := arg0.Float32 + float640 := arg0.Float64 + uint320 := arg0.Uint32 + uint640 := arg0.Uint64 + s0 := arg0.S + vf320 := arg0.Vf32 + vf640 := arg0.Vf64 + // CoreF32FromF32 + result1 := api.EncodeF32(float32(float320)) + // CoreF64FromF64 + result2 := api.EncodeF64(float64(float640)) + // I32FromU32 + result3 := api.EncodeU32(uint320) + // I64FromU64 + value4 := int64(uint640) + // StringLower { realloc: Some("cabi_realloc") } + memory5 := i.module.Memory() + realloc5 := i.module.ExportedFunction("cabi_realloc") + ptr5, len5, err5 := writeString(ctx, s0, memory5, realloc5) + // The return type doesn't contain an error so we panic if one is encountered + if err5 != nil { + panic(err5) + } + // ListLower { element: F32, realloc: Some("cabi_realloc") } + vec7 := vf320 + len7 := uint64(len(vec7)) + result7, err7 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 4, len7 * 4) + // The return type doesn't contain an error so we panic if one is encountered + if err7 != nil { + panic(err7) + } + ptr7 := result7[0] + for idx := uint64(0); idx < len7; idx++ { + e := vec7[idx] + base := uint32(ptr7 + uint64(idx) * uint64(4)) + // IterElem { element: F32 } + // IterBasePointer + // CoreF32FromF32 + result6 := api.EncodeF32(float32(e)) + // F32Store { offset: 0 } + i.module.Memory().WriteUint64Le(uint32(base+0), result6) + } + // ListLower { element: F64, realloc: Some("cabi_realloc") } + vec9 := vf640 + len9 := uint64(len(vec9)) + result9, err9 := i.module.ExportedFunction("cabi_realloc").Call(ctx, 0, 0, 8, len9 * 8) + // The return type doesn't contain an error so we panic if one is encountered + if err9 != nil { + panic(err9) + } + ptr9 := result9[0] + for idx := uint64(0); idx < len9; idx++ { + e := vec9[idx] + base := uint32(ptr9 + uint64(idx) * uint64(8)) + // IterElem { element: F64 } + // IterBasePointer + // CoreF64FromF64 + result8 := api.EncodeF64(float64(e)) + // F64Store { offset: 0 } + i.module.Memory().WriteUint64Le(uint32(base+0), result8) + } + // CallWasm { name: "modify-foo", sig: WasmSignature { params: [F32, F64, I32, I64, Pointer, Length, Pointer, Length, Pointer, Length], results: [Pointer], indirect_params: false, retptr: true } } + raw10, err10 := i.module.ExportedFunction("modify-foo").Call(ctx, uint64(result1), uint64(result2), uint64(result3), uint64(value4), uint64(ptr5), uint64(len5), uint64(ptr7), uint64(len7), uint64(ptr9), uint64(len9)) + // The return type doesn't contain an error so we panic if one is encountered + if err10 != nil { + panic(err10) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if _, err := i.module.ExportedFunction("cabi_post_modify-foo").Call(ctx, raw10...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + }() + + results10 := raw10[0] + // F32Load { offset: 0 } + value11, ok11 := i.module.Memory().ReadUint64Le(uint32(results10 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok11 { + panic(errors.New("failed to read f64 from memory")) + } + // F32FromCoreF32 + result12 := api.DecodeF32(value11) + // F64Load { offset: 8 } + value13, ok13 := i.module.Memory().ReadUint64Le(uint32(results10 + 8)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok13 { + panic(errors.New("failed to read f64 from memory")) + } + // F64FromCoreF64 + result14 := api.DecodeF64(value13) + // I32Load { offset: 16 } + value15, ok15 := i.module.Memory().ReadUint32Le(uint32(results10 + 16)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok15 { + panic(errors.New("failed to read i32 from memory")) + } + // U32FromI32 + result16 := api.DecodeU32(uint64(value15)) + // I64Load { offset: 24 } + value17, ok17 := i.module.Memory().ReadUint64Le(uint32(results10 + 24)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok17 { + panic(errors.New("failed to read i64 from memory")) + } + // U64FromI64 + value18 := uint64(value17) + // PointerLoad { offset: 32 } + ptr19, ok19 := i.module.Memory().ReadUint32Le(uint32(results10 + 32)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok19 { + panic(errors.New("failed to read pointer from memory")) + } + // LengthLoad { offset: (32+1*ptrsz) } + len20, ok20 := i.module.Memory().ReadUint32Le(uint32(results10 + 36)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok20 { + panic(errors.New("failed to read length from memory")) + } + // StringLift + buf21, ok21 := i.module.Memory().Read(ptr19, len20) + // The return type doesn't contain an error so we panic if one is encountered + if !ok21 { + panic(errors.New("failed to read bytes from memory")) + } + str21 := string(buf21) + // PointerLoad { offset: (32+2*ptrsz) } + ptr22, ok22 := i.module.Memory().ReadUint32Le(uint32(results10 + 40)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok22 { + panic(errors.New("failed to read pointer from memory")) + } + // LengthLoad { offset: (32+3*ptrsz) } + len23, ok23 := i.module.Memory().ReadUint32Le(uint32(results10 + 44)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok23 { + panic(errors.New("failed to read length from memory")) + } + // ListLift { element: F32, ty: Id { idx: 0 } } + base26 := ptr22 + len26 := len23 + result26 := make([]float32, len26) + for idx26 := uint32(0); idx26 < len26; idx26++ { + base := base26 + idx26 * 4 + // IterBasePointer + // F32Load { offset: 0 } + value24, ok24 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok24 { + panic(errors.New("failed to read f64 from memory")) + } + // F32FromCoreF32 + result25 := api.DecodeF32(value24) + result26[idx26] = result25 + } + // PointerLoad { offset: (32+4*ptrsz) } + ptr27, ok27 := i.module.Memory().ReadUint32Le(uint32(results10 + 48)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok27 { + panic(errors.New("failed to read pointer from memory")) + } + // LengthLoad { offset: (32+5*ptrsz) } + len28, ok28 := i.module.Memory().ReadUint32Le(uint32(results10 + 52)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok28 { + panic(errors.New("failed to read length from memory")) + } + // ListLift { element: F64, ty: Id { idx: 1 } } + base31 := ptr27 + len31 := len28 + result31 := make([]float64, len31) + for idx31 := uint32(0); idx31 < len31; idx31++ { + base := base31 + idx31 * 8 + // IterBasePointer + // F64Load { offset: 0 } + value29, ok29 := i.module.Memory().ReadUint64Le(uint32(base + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok29 { + panic(errors.New("failed to read f64 from memory")) + } + // F64FromCoreF64 + result30 := api.DecodeF64(value29) + result31[idx31] = result30 + } + // RecordLift { record: Record { fields: [Field { name: "float32", ty: F32, docs: Docs { contents: None } }, Field { name: "float64", ty: F64, docs: Docs { contents: None } }, Field { name: "uint32", ty: U32, docs: Docs { contents: None } }, Field { name: "uint64", ty: U64, docs: Docs { contents: None } }, Field { name: "s", ty: String, docs: Docs { contents: None } }, Field { name: "vf32", ty: Id(Id { idx: 0 }), docs: Docs { contents: None } }, Field { name: "vf64", ty: Id(Id { idx: 1 }), docs: Docs { contents: None } }] }, name: "foo", ty: Id { idx: 2 } } + value32 := Foo{ + Float32: result12, + Float64: result14, + Uint32: result16, + Uint64: value18, + S: str21, + Vf32: result26, + Vf64: result31, + } + // Flush { amt: 1 } + // Return { amt: 1, func: Function { name: "modify-foo", kind: Freestanding, params: [("f", Id(Id { idx: 3 }))], result: Some(Id(Id { idx: 3 })), docs: Docs { contents: None }, stability: Unknown } } + return value32 +} + diff --git a/cmd/gravity/tests/cmd/records.toml b/cmd/gravity/tests/cmd/records.toml new file mode 100644 index 0000000..4cba21d --- /dev/null +++ b/cmd/gravity/tests/cmd/records.toml @@ -0,0 +1,2 @@ +bin.name = "gravity" +args = "--world records ../../target/wasm32-unknown-unknown/release/example_records.wasm" diff --git a/cmd/gravity/tests/cmd/resources-simple.stderr b/cmd/gravity/tests/cmd/resources-simple.stderr new file mode 100644 index 0000000..e69de29 diff --git a/cmd/gravity/tests/cmd/resources-simple.stdout b/cmd/gravity/tests/cmd/resources-simple.stdout new file mode 100644 index 0000000..0a276be --- /dev/null +++ b/cmd/gravity/tests/cmd/resources-simple.stdout @@ -0,0 +1,357 @@ +// Code generated by arcjet-gravity; DO NOT EDIT. + +package resources + +import "context" +import "errors" +import "github.com/tetratelabs/wazero" +import "github.com/tetratelabs/wazero/api" +import "sync" + +import _ "embed" + +//go:embed resources.wasm +var wasmFileResources []byte + +// ifaceFooerHandle is a handle to the fooer resource in the iface interface. +type ifaceFooerHandle uint32 + +type IfaceFooer interface { + GetX(ctx context.Context) uint32 + SetX(ctx context.Context, x uint32) + GetY(ctx context.Context) string + SetY(ctx context.Context, y string) +} + +type IResourcesIface[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]] interface { + NewFooer(ctx context.Context, x uint32, y string) TIfaceFooerValue +} + +// PIfaceFooer constrains a pointer to a type implementing the IfaceFooer interface. +type PIfaceFooer[TIfaceFooerValue any] interface { + *TIfaceFooerValue + IfaceFooer +} + +// ifaceFooerResourceTable is a resource table for fooer resources from the iface interface. +type ifaceFooerResourceTable[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]] struct { + mu sync.Mutex + nextHandle uint32 + table map[ifaceFooerHandle]*TIfaceFooerValue +} + +func newIfaceFooerResourceTable[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]]() *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer] { + return &ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]{ + nextHandle: 1, + table: make(map[ifaceFooerHandle]*TIfaceFooerValue), + } +} + +// Store adds a resource to the table and returns its handle. +func (t *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]) Store(resource TIfaceFooerValue) ifaceFooerHandle { + t.mu.Lock() + defer t.mu.Unlock() + handle := ifaceFooerHandle(t.nextHandle) + t.nextHandle++ + t.table[handle] = &resource + return handle +} + +// get returns a pointer to the resource from the table by its handle. +func (t *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]) get(handle ifaceFooerHandle) (PTIfaceFooer, bool) { + t.mu.Lock() + defer t.mu.Unlock() + resource, ok := t.table[handle] + if !ok { + var zero PTIfaceFooer + return zero, false + } + return resource, true +} + +// Get retrieves a resource from the table by its handle. +func (t *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]) Get(handle ifaceFooerHandle) (TIfaceFooerValue, bool) { + t.mu.Lock() + defer t.mu.Unlock() + resource, ok := t.table[handle] + if !ok { + var zero TIfaceFooerValue + return zero, false + } + return *resource, true +} + +// Remove deletes a resource from the table. +func (t *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]) Remove(handle ifaceFooerHandle) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.table, handle) +} + +type ResourcesFactory[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]] struct { + runtime wazero.Runtime + module wazero.CompiledModule + IfaceFooerResourceTable *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer] +} + +func NewResourcesFactory[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]]( + ctx context.Context, + iface IResourcesIface[TIfaceFooerValue, PTIfaceFooer], +) (*ResourcesFactory[TIfaceFooerValue, PTIfaceFooer], error) { + wazeroRuntime := wazero.NewRuntime(ctx) + // Initialize resource tables before host module instantiation + ifaceFooerResourceTable := newIfaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]() + // Instantiate import host modules + _, err0 := wazeroRuntime.NewHostModuleBuilder("arcjet:resources/iface"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + arg2 uint32, + ) uint32{ + // GetArg { nth: 0 } + // U32FromI32 + result0 := api.DecodeU32(uint64(arg0)) + // GetArg { nth: 1 } + // GetArg { nth: 2 } + // StringLift + buf1, ok1 := mod.Memory().Read(arg1, arg2) + if !ok1 { + panic(errors.New("failed to read bytes from memory")) + } + str1 := string(buf1) + // CallInterface { func: Function { name: "[constructor]fooer", kind: Constructor(Id { idx: 0 }), params: [("x", U32), ("y", String)], result: Some(Id(Id { idx: 1 })), docs: Docs { contents: None }, stability: Unknown }, async_: false } + value2 := iface.NewFooer(ctx, result0, str1) + // HandleLower { handle: Own(Id { idx: 0 }), name: "fooer", ty: Id { idx: 1 } } + converted3 := uint32(ifaceFooerResourceTable.Store(value2)) + // Return { amt: 1, func: Function { name: "[constructor]fooer", kind: Constructor(Id { idx: 0 }), params: [("x", U32), ("y", String)], result: Some(Id(Id { idx: 1 })), docs: Docs { contents: None }, stability: Unknown } } + return converted3 + }). + Export("[constructor]fooer"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + ) uint32{ + // GetArg { nth: 0 } + // HandleLift { handle: Borrow(Id { idx: 0 }), name: "fooer", ty: Id { idx: 2 } } + converted0 := ifaceFooerHandle(arg0) + // CallInterface { func: Function { name: "[method]fooer.get-x", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 }))], result: Some(U32), docs: Docs { contents: None }, stability: Unknown }, async_: false } + resource1, ok1 := ifaceFooerResourceTable.get(converted0) + if !ok1 { + panic("invalid resource handle") + } + value1 := resource1.GetX(ctx) + // I32FromU32 + result2 := value1 + // Return { amt: 1, func: Function { name: "[method]fooer.get-x", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 }))], result: Some(U32), docs: Docs { contents: None }, stability: Unknown } } + return result2 + }). + Export("[method]fooer.get-x"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + ) { + // GetArg { nth: 0 } + // HandleLift { handle: Borrow(Id { idx: 0 }), name: "fooer", ty: Id { idx: 2 } } + converted0 := ifaceFooerHandle(arg0) + // GetArg { nth: 1 } + // U32FromI32 + result1 := api.DecodeU32(uint64(arg1)) + // CallInterface { func: Function { name: "[method]fooer.set-x", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 })), ("x", U32)], result: None, docs: Docs { contents: None }, stability: Unknown }, async_: false } + resource2, ok2 := ifaceFooerResourceTable.get(converted0) + if !ok2 { + panic("invalid resource handle") + } + resource2.SetX(ctx, result1) + // Return { amt: 0, func: Function { name: "[method]fooer.set-x", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 })), ("x", U32)], result: None, docs: Docs { contents: None }, stability: Unknown } } + }). + Export("[method]fooer.set-x"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + ) { + // GetArg { nth: 0 } + // HandleLift { handle: Borrow(Id { idx: 0 }), name: "fooer", ty: Id { idx: 2 } } + converted0 := ifaceFooerHandle(arg0) + // CallInterface { func: Function { name: "[method]fooer.get-y", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 }))], result: Some(String), docs: Docs { contents: None }, stability: Unknown }, async_: false } + resource1, ok1 := ifaceFooerResourceTable.get(converted0) + if !ok1 { + panic("invalid resource handle") + } + value1 := resource1.GetY(ctx) + // GetArg { nth: 1 } + // StringLower { realloc: Some("cabi_realloc") } + memory2 := mod.Memory() + realloc2 := mod.ExportedFunction("cabi_realloc") + ptr2, len2, err2 := writeString(ctx, value1, memory2, realloc2) + if err2 != nil { + panic(err2) + } + // LengthStore { offset: ptrsz } + mod.Memory().WriteUint32Le(uint32(arg1+4), uint32(len2)) + // PointerStore { offset: 0 } + mod.Memory().WriteUint32Le(uint32(arg1+0), uint32(ptr2)) + // Return { amt: 0, func: Function { name: "[method]fooer.get-y", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 }))], result: Some(String), docs: Docs { contents: None }, stability: Unknown } } + }). + Export("[method]fooer.get-y"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + arg2 uint32, + ) { + // GetArg { nth: 0 } + // HandleLift { handle: Borrow(Id { idx: 0 }), name: "fooer", ty: Id { idx: 2 } } + converted0 := ifaceFooerHandle(arg0) + // GetArg { nth: 1 } + // GetArg { nth: 2 } + // StringLift + buf1, ok1 := mod.Memory().Read(arg1, arg2) + if !ok1 { + panic(errors.New("failed to read bytes from memory")) + } + str1 := string(buf1) + // CallInterface { func: Function { name: "[method]fooer.set-y", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 })), ("y", String)], result: None, docs: Docs { contents: None }, stability: Unknown }, async_: false } + resource2, ok2 := ifaceFooerResourceTable.get(converted0) + if !ok2 { + panic("invalid resource handle") + } + resource2.SetY(ctx, str1) + // Return { amt: 0, func: Function { name: "[method]fooer.set-y", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 })), ("y", String)], result: None, docs: Docs { contents: None }, stability: Unknown } } + }). + Export("[method]fooer.set-y"). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, mod api.Module, arg0 uint32) { + // [resource-drop]: called when guest drops a resource + // + // With borrow-only parameters, guests never take ownership of host resources. + // Resources stay in host table until host explicitly removes them. + // This callback is a no-op since host controls the full lifecycle. + // + // Note: If we add owned parameter support in the future, this would need + // to implement ref-counting and state tracking to properly cleanup consumed resources. + _ = arg0 + }). + Export("[resource-drop]fooer"). + Instantiate(ctx) + if err0 != nil { + return nil, err0 + } + // Instantiate export resource management host modules + // Compiling the module takes a LONG time, so we want to do it once and hold + // onto it with the Runtime + module, err := wazeroRuntime.CompileModule(ctx, wasmFileResources) + if err != nil { + return nil, err + } + return &ResourcesFactory[TIfaceFooerValue, PTIfaceFooer]{ + runtime: wazeroRuntime, + module: module, + IfaceFooerResourceTable: ifaceFooerResourceTable, + }, nil +} + +func (f *ResourcesFactory[TIfaceFooerValue, PTIfaceFooer]) Instantiate(ctx context.Context) (*ResourcesInstance[TIfaceFooerValue, PTIfaceFooer], error) { + if module, err := f.runtime.InstantiateModule(ctx, f.module, wazero.NewModuleConfig()); err != nil { + return nil, err + } else { + return &ResourcesInstance[TIfaceFooerValue, PTIfaceFooer]{ + module: module, + IfaceFooerResourceTable: f.IfaceFooerResourceTable, + }, nil + } +} + +func (f *ResourcesFactory[TIfaceFooerValue, PTIfaceFooer]) Close(ctx context.Context) { + f.runtime.Close(ctx) +} + +type ResourcesInstance[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]] struct { + module api.Module + IfaceFooerResourceTable *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer] +} + +func (i *ResourcesInstance[TIfaceFooerValue, PTIfaceFooer]) Close(ctx context.Context) error { + if err := i.module.Close(ctx); err != nil { + return err + } + + return nil +} + +// writeString will put a Go string into the Wasm memory following the Component +// Model calling convetions, such as allocating memory with the realloc function +func writeString( + ctx context.Context, + s string, + memory api.Memory, + realloc api.Function, +) (uint64, uint64, error) { + if len(s) == 0 { + return 1, 0, nil + } + + results, err := realloc.Call(ctx, 0, 0, 1, uint64(len(s))) + if err != nil { + return 1, 0, err + } + ptr := results[0] + ok := memory.Write(uint32(ptr), []byte(s)) + if !ok { + return 1, 0, err + } + return uint64(ptr), uint64(len(s)), nil +} + +func (i *ResourcesInstance[TIfaceFooerValue, PTIfaceFooer]) UseFooer( + ctx context.Context, + foo ifaceFooerHandle, +) { + arg0 := foo + // GetArg { nth: 0 } + // HandleLower { handle: Borrow(Id { idx: 3 }), name: "fooer", ty: Id { idx: 4 } } + converted0 := uint32(arg0) + // CallWasm { name: "use-fooer", sig: WasmSignature { params: [I32], results: [], indirect_params: false, retptr: false } } + _, err1 := i.module.ExportedFunction("use-fooer").Call(ctx, uint64(converted0)) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + + // Return { amt: 0, func: Function { name: "use-fooer", kind: Freestanding, params: [("foo", Id(Id { idx: 4 }))], result: None, docs: Docs { contents: None }, stability: Unknown } } +} + +func (i *ResourcesInstance[TIfaceFooerValue, PTIfaceFooer]) UseFooerReturnNew( + ctx context.Context, + foo ifaceFooerHandle, +) ifaceFooerHandle { + arg0 := foo + // GetArg { nth: 0 } + // HandleLower { handle: Borrow(Id { idx: 3 }), name: "fooer", ty: Id { idx: 4 } } + converted0 := uint32(arg0) + // CallWasm { name: "use-fooer-return-new", sig: WasmSignature { params: [I32], results: [I32], indirect_params: false, retptr: false } } + raw1, err1 := i.module.ExportedFunction("use-fooer-return-new").Call(ctx, uint64(converted0)) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + + results1 := raw1[0] + // HandleLift { handle: Own(Id { idx: 3 }), name: "fooer", ty: Id { idx: 5 } } + converted2 := ifaceFooerHandle(results1) + // Return { amt: 1, func: Function { name: "use-fooer-return-new", kind: Freestanding, params: [("foo", Id(Id { idx: 4 }))], result: Some(Id(Id { idx: 5 })), docs: Docs { contents: None }, stability: Unknown } } + return converted2 +} + diff --git a/cmd/gravity/tests/cmd/resources-simple.toml b/cmd/gravity/tests/cmd/resources-simple.toml new file mode 100644 index 0000000..2297466 --- /dev/null +++ b/cmd/gravity/tests/cmd/resources-simple.toml @@ -0,0 +1,2 @@ +bin.name = "gravity" +args = "--world resources ../../target/wasm32-unknown-unknown/release/example_resources_simple.wasm" diff --git a/cmd/gravity/tests/cmd/tuples.stderr b/cmd/gravity/tests/cmd/tuples.stderr new file mode 100644 index 0000000..e69de29 diff --git a/cmd/gravity/tests/cmd/tuples.stdout b/cmd/gravity/tests/cmd/tuples.stdout new file mode 100644 index 0000000..56a4c76 --- /dev/null +++ b/cmd/gravity/tests/cmd/tuples.stdout @@ -0,0 +1,276 @@ +// Code generated by arcjet-gravity; DO NOT EDIT. + +package tuples + +import "context" +import "errors" +import "github.com/tetratelabs/wazero" +import "github.com/tetratelabs/wazero/api" + +import _ "embed" + +//go:embed tuples.wasm +var wasmFileTuples []byte + +type ITuplesTypes interface {} + +type CustomTuple struct { + F0 uint32 + + F1 float64 + + F2 string +} + +type TuplesFactory struct { + runtime wazero.Runtime + module wazero.CompiledModule +} + +func NewTuplesFactory( + ctx context.Context, + types ITuplesTypes, +) (*TuplesFactory, error) { + wazeroRuntime := wazero.NewRuntime(ctx) + + // Instantiate import host modules + _, err0 := wazeroRuntime.NewHostModuleBuilder("arcjet:tuples/types"). + Instantiate(ctx) + if err0 != nil { + return nil, err0 + } + + // Compiling the module takes a LONG time, so we want to do it once and hold + // onto it with the Runtime + module, err := wazeroRuntime.CompileModule(ctx, wasmFileTuples) + if err != nil { + return nil, err + } + return &TuplesFactory{ + runtime: wazeroRuntime, + module: module, + }, nil +} + +func (f *TuplesFactory) Instantiate(ctx context.Context) (*TuplesInstance, error) { + if module, err := f.runtime.InstantiateModule(ctx, f.module, wazero.NewModuleConfig()); err != nil { + return nil, err + } else { + return &TuplesInstance{module}, nil + } +} + +func (f *TuplesFactory) Close(ctx context.Context) { + f.runtime.Close(ctx) +} + +type TuplesInstance struct { + module api.Module +} + +func (i *TuplesInstance) Close(ctx context.Context) error { + if err := i.module.Close(ctx); err != nil { + return err + } + + return nil +} + +// writeString will put a Go string into the Wasm memory following the Component +// Model calling convetions, such as allocating memory with the realloc function +func writeString( + ctx context.Context, + s string, + memory api.Memory, + realloc api.Function, +) (uint64, uint64, error) { + if len(s) == 0 { + return 1, 0, nil + } + + results, err := realloc.Call(ctx, 0, 0, 1, uint64(len(s))) + if err != nil { + return 1, 0, err + } + ptr := results[0] + ok := memory.Write(uint32(ptr), []byte(s)) + if !ok { + return 1, 0, err + } + return uint64(ptr), uint64(len(s)), nil +} + +func (i *TuplesInstance) CustomTupleFunc( + ctx context.Context, + t CustomTuple, +) CustomTuple { + arg0 := t + // GetArg { nth: 0 } + // TupleLower { tuple: Tuple { types: [U32, F64, String] }, ty: Id { idx: 0 } } + f00 := arg0.F0 + f01 := arg0.F1 + f02 := arg0.F2 + // I32FromU32 + result1 := api.EncodeU32(f00) + // CoreF64FromF64 + result2 := api.EncodeF64(float64(f01)) + // StringLower { realloc: Some("cabi_realloc") } + memory3 := i.module.Memory() + realloc3 := i.module.ExportedFunction("cabi_realloc") + ptr3, len3, err3 := writeString(ctx, f02, memory3, realloc3) + // The return type doesn't contain an error so we panic if one is encountered + if err3 != nil { + panic(err3) + } + // CallWasm { name: "custom-tuple-func", sig: WasmSignature { params: [I32, F64, Pointer, Length], results: [Pointer], indirect_params: false, retptr: true } } + raw4, err4 := i.module.ExportedFunction("custom-tuple-func").Call(ctx, uint64(result1), uint64(result2), uint64(ptr3), uint64(len3)) + // The return type doesn't contain an error so we panic if one is encountered + if err4 != nil { + panic(err4) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if _, err := i.module.ExportedFunction("cabi_post_custom-tuple-func").Call(ctx, raw4...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + }() + + results4 := raw4[0] + // I32Load { offset: 0 } + value5, ok5 := i.module.Memory().ReadUint32Le(uint32(results4 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok5 { + panic(errors.New("failed to read i32 from memory")) + } + // U32FromI32 + result6 := api.DecodeU32(uint64(value5)) + // F64Load { offset: 8 } + value7, ok7 := i.module.Memory().ReadUint64Le(uint32(results4 + 8)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok7 { + panic(errors.New("failed to read f64 from memory")) + } + // F64FromCoreF64 + result8 := api.DecodeF64(value7) + // PointerLoad { offset: 16 } + ptr9, ok9 := i.module.Memory().ReadUint32Le(uint32(results4 + 16)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok9 { + panic(errors.New("failed to read pointer from memory")) + } + // LengthLoad { offset: (16+1*ptrsz) } + len10, ok10 := i.module.Memory().ReadUint32Le(uint32(results4 + 20)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok10 { + panic(errors.New("failed to read length from memory")) + } + // StringLift + buf11, ok11 := i.module.Memory().Read(ptr9, len10) + // The return type doesn't contain an error so we panic if one is encountered + if !ok11 { + panic(errors.New("failed to read bytes from memory")) + } + str11 := string(buf11) + // TupleLift { tuple: Tuple { types: [U32, F64, String] }, ty: Id { idx: 0 } } + var value12 CustomTuple + value12.F0 = result6 + value12.F1 = result8 + value12.F2 = str11 + // Flush { amt: 1 } + // Return { amt: 1, func: Function { name: "custom-tuple-func", kind: Freestanding, params: [("t", Id(Id { idx: 1 }))], result: Some(Id(Id { idx: 1 })), docs: Docs { contents: None }, stability: Unknown } } + return value12 +} + +func (i *TuplesInstance) AnonymousTupleFunc( + ctx context.Context, + t struct{F0 uint32; F1 float64; F2 string}, +) struct{F0 uint32; F1 float64; F2 string} { + arg0 := t + // GetArg { nth: 0 } + // TupleLower { tuple: Tuple { types: [U32, F64, String] }, ty: Id { idx: 2 } } + f00 := arg0.F0 + f01 := arg0.F1 + f02 := arg0.F2 + // I32FromU32 + result1 := api.EncodeU32(f00) + // CoreF64FromF64 + result2 := api.EncodeF64(float64(f01)) + // StringLower { realloc: Some("cabi_realloc") } + memory3 := i.module.Memory() + realloc3 := i.module.ExportedFunction("cabi_realloc") + ptr3, len3, err3 := writeString(ctx, f02, memory3, realloc3) + // The return type doesn't contain an error so we panic if one is encountered + if err3 != nil { + panic(err3) + } + // CallWasm { name: "anonymous-tuple-func", sig: WasmSignature { params: [I32, F64, Pointer, Length], results: [Pointer], indirect_params: false, retptr: true } } + raw4, err4 := i.module.ExportedFunction("anonymous-tuple-func").Call(ctx, uint64(result1), uint64(result2), uint64(ptr3), uint64(len3)) + // The return type doesn't contain an error so we panic if one is encountered + if err4 != nil { + panic(err4) + } + + // The cleanup via `cabi_post_*` cleans up the memory in the guest. By + // deferring this, we ensure that no memory is corrupted before the function + // is done accessing it. + defer func() { + if _, err := i.module.ExportedFunction("cabi_post_anonymous-tuple-func").Call(ctx, raw4...); err != nil { + // If we get an error during cleanup, something really bad is + // going on, so we panic. Also, you can't return the error from + // the `defer` + panic(errors.New("failed to cleanup")) + } + }() + + results4 := raw4[0] + // I32Load { offset: 0 } + value5, ok5 := i.module.Memory().ReadUint32Le(uint32(results4 + 0)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok5 { + panic(errors.New("failed to read i32 from memory")) + } + // U32FromI32 + result6 := api.DecodeU32(uint64(value5)) + // F64Load { offset: 8 } + value7, ok7 := i.module.Memory().ReadUint64Le(uint32(results4 + 8)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok7 { + panic(errors.New("failed to read f64 from memory")) + } + // F64FromCoreF64 + result8 := api.DecodeF64(value7) + // PointerLoad { offset: 16 } + ptr9, ok9 := i.module.Memory().ReadUint32Le(uint32(results4 + 16)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok9 { + panic(errors.New("failed to read pointer from memory")) + } + // LengthLoad { offset: (16+1*ptrsz) } + len10, ok10 := i.module.Memory().ReadUint32Le(uint32(results4 + 20)) + // The return type doesn't contain an error so we panic if one is encountered + if !ok10 { + panic(errors.New("failed to read length from memory")) + } + // StringLift + buf11, ok11 := i.module.Memory().Read(ptr9, len10) + // The return type doesn't contain an error so we panic if one is encountered + if !ok11 { + panic(errors.New("failed to read bytes from memory")) + } + str11 := string(buf11) + // TupleLift { tuple: Tuple { types: [U32, F64, String] }, ty: Id { idx: 2 } } + var value12 struct{F0 uint32; F1 float64; F2 string} + value12.F0 = result6 + value12.F1 = result8 + value12.F2 = str11 + // Flush { amt: 1 } + // Return { amt: 1, func: Function { name: "anonymous-tuple-func", kind: Freestanding, params: [("t", Id(Id { idx: 2 }))], result: Some(Id(Id { idx: 2 })), docs: Docs { contents: None }, stability: Unknown } } + return value12 +} + diff --git a/cmd/gravity/tests/cmd/tuples.toml b/cmd/gravity/tests/cmd/tuples.toml new file mode 100644 index 0000000..df5c292 --- /dev/null +++ b/cmd/gravity/tests/cmd/tuples.toml @@ -0,0 +1,2 @@ +bin.name = "gravity" +args = "--world tuples ../../target/wasm32-unknown-unknown/release/example_tuples.wasm" diff --git a/examples/basic/basic_test.go b/examples/basic/basic_test.go index ce0f486..e05d5db 100644 --- a/examples/basic/basic_test.go +++ b/examples/basic/basic_test.go @@ -71,7 +71,7 @@ func TestNoOptionalPrimitiveCleanup(t *testing.T) { } defer ins.Close(t.Context()) - actual, ok := ins.OptionalPrimitive(t.Context()) + actual, ok := ins.OptionalPrimitive(t.Context(), true) if !ok { t.Fatal(err) } diff --git a/examples/basic/src/lib.rs b/examples/basic/src/lib.rs index 6ef53fc..bc63b07 100644 --- a/examples/basic/src/lib.rs +++ b/examples/basic/src/lib.rs @@ -17,10 +17,13 @@ impl Guest for BasicWorld { fn primitive() -> bool { true } - fn optional_primitive() -> Option { + fn optional_primitive(_: Option) -> Option { Some(true) } fn result_primitive() -> Result { Ok(true) } + fn optional_string(s: Option) -> Option { + s + } } diff --git a/examples/basic/wit/basic.wit b/examples/basic/wit/basic.wit index 395db03..ba5cebe 100644 --- a/examples/basic/wit/basic.wit +++ b/examples/basic/wit/basic.wit @@ -12,6 +12,8 @@ world basic { export hello: func() -> result; export primitive: func() -> bool; - export optional-primitive: func() -> option; + export optional-primitive: func(b: option) -> option; export result-primitive: func() -> result; + + export optional-string: func(s: option) -> option; } diff --git a/examples/generate.go b/examples/generate.go index 47058c8..149d7bd 100644 --- a/examples/generate.go +++ b/examples/generate.go @@ -1,9 +1,19 @@ package examples //go:generate cargo build -p example-basic --target wasm32-unknown-unknown --release +//go:generate cargo build -p example-records --target wasm32-unknown-unknown --release //go:generate cargo build -p example-iface-method-returns-string --target wasm32-unknown-unknown --release //go:generate cargo build -p example-instructions --target wasm32-unknown-unknown --release +//go:generate sh -c "RUSTFLAGS='--cfg getrandom_backend=\"custom\"' cargo build -q -p example-outlier --target wasm32-unknown-unknown --release" +//go:generate cargo build -p example-resources --target wasm32-unknown-unknown --release +//go:generate cargo build -p example-resources-simple --target wasm32-unknown-unknown --release +//go:generate cargo build -p example-tuples --target wasm32-unknown-unknown --release //go:generate cargo run --bin gravity -- --world basic --output ./basic/basic.go ../target/wasm32-unknown-unknown/release/example_basic.wasm +//go:generate cargo run --bin gravity -- --world records --output ./records/records.go ../target/wasm32-unknown-unknown/release/example_records.wasm //go:generate cargo run --bin gravity -- --world example --output ./iface-method-returns-string/example.go ../target/wasm32-unknown-unknown/release/example_iface_method_returns_string.wasm //go:generate cargo run --bin gravity -- --world instructions --output ./instructions/bindings.go ../target/wasm32-unknown-unknown/release/example_instructions.wasm +//go:generate cargo run --bin gravity -- --world outlier --output ./outlier/outlier.go ../target/wasm32-unknown-unknown/release/example_outlier.wasm +//go:generate cargo run --bin gravity -- --world resources --output ./resources/resources.go ../target/wasm32-unknown-unknown/release/example_resources.wasm +//go:generate cargo run --bin gravity -- --world resources --output ./resources-simple/resources.go ../target/wasm32-unknown-unknown/release/example_resources_simple.wasm +//go:generate cargo run --bin gravity -- --world tuples --output ./tuples/tuples.go ../target/wasm32-unknown-unknown/release/example_tuples.wasm diff --git a/examples/instructions/instructions_test.go b/examples/instructions/instructions_test.go index 5914c7d..5ce3f47 100644 --- a/examples/instructions/instructions_test.go +++ b/examples/instructions/instructions_test.go @@ -1,12 +1,16 @@ package instructions import ( + "fmt" "iter" "math" + "math/rand/v2" "testing" ) -func inclusive[Num interface { ~int8 | ~uint8 | ~int16 | ~uint16 }](start Num, end Num) iter.Seq[Num] { +func inclusive[Num interface { + ~int8 | ~uint8 | ~int16 | ~uint16 +}](start Num, end Num) iter.Seq[Num] { return func(yield func(v Num) bool) { var next Num = start for { @@ -21,7 +25,7 @@ func inclusive[Num interface { ~int8 | ~uint8 | ~int16 | ~uint16 }](start Num, e } } } -func inclusiveStep[Num interface { ~int32 | ~uint32 }](start Num, end Num, step Num) iter.Seq[Num] { +func inclusiveStep[Num interface{ ~int32 | ~uint32 }](start Num, end Num, step Num) iter.Seq[Num] { return func(yield func(v Num) bool) { var next Num = start for { @@ -32,7 +36,7 @@ func inclusiveStep[Num interface { ~int32 | ~uint32 }](start Num, end Num, step return } - if end - step > next { + if end-step > next { next += step } else { next = end @@ -166,3 +170,55 @@ func Test_U32Roundtrip(t *testing.T) { } } } + +func Test_F32Roundtrip(t *testing.T) { + fac, err := NewInstructionsFactory(t.Context()) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + // Generate a bunch of random floats and check they all roundtrip correctly. + seed := 123456 + rng := rand.New(rand.NewPCG(uint64(seed), uint64(seed))) + for i := range 1000 { + t.Run(fmt.Sprintf("i: %d", i), func(t *testing.T) { + expected := rng.Float32() + if actual := ins.F32Roundtrip(t.Context(), expected); actual != expected { + t.Errorf("expected: %f, but got: %f", expected, actual) + } + }) + } +} + +func Test_F64Roundtrip(t *testing.T) { + fac, err := NewInstructionsFactory(t.Context()) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + // Generate a bunch of random floats and check they all roundtrip correctly. + seed := 123456 + rng := rand.New(rand.NewPCG(uint64(seed), uint64(seed))) + for i := range 1000 { + t.Run(fmt.Sprintf("i: %d", i), func(t *testing.T) { + expected := rng.Float64() + if actual := ins.F64Roundtrip(t.Context(), expected); actual != expected { + t.Errorf("expected: %f, but got: %f", expected, actual) + } + }) + } +} diff --git a/examples/instructions/src/lib.rs b/examples/instructions/src/lib.rs index b9d1984..13c8a1e 100644 --- a/examples/instructions/src/lib.rs +++ b/examples/instructions/src/lib.rs @@ -31,4 +31,12 @@ impl Guest for InstructionsWorld { assert!((u32::MIN..=u32::MAX).contains(&val)); val } + fn f32_roundtrip(val: f32) -> f32 { + assert!((f32::MIN..=f32::MAX).contains(&val)); + val + } + fn f64_roundtrip(val: f64) -> f64 { + assert!((f64::MIN..=f64::MAX).contains(&val)); + val + } } diff --git a/examples/instructions/wit/instructions.wit b/examples/instructions/wit/instructions.wit index abfd3ed..84888f8 100644 --- a/examples/instructions/wit/instructions.wit +++ b/examples/instructions/wit/instructions.wit @@ -12,4 +12,8 @@ world instructions { export s32-roundtrip: func(val: s32) -> s32; export u32-roundtrip: func(val: u32) -> u32; + + export f32-roundtrip: func(val: f32) -> f32; + + export f64-roundtrip: func(val: f64) -> f64; } diff --git a/examples/outlier/Cargo.toml b/examples/outlier/Cargo.toml new file mode 100644 index 0000000..af1b814 --- /dev/null +++ b/examples/outlier/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "example-outlier" +version = "0.0.2" +edition = "2024" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +augurs-outlier = { git = "https://github.com/grafana/augurs", branch = "rand-0.9", features = ["serde"] } +getrandom = { version = "0.3" } +wit-bindgen = "=0.46.0" +wit-component = "=0.239.0" diff --git a/examples/outlier/README.md b/examples/outlier/README.md new file mode 100644 index 0000000..0d32659 --- /dev/null +++ b/examples/outlier/README.md @@ -0,0 +1,40 @@ +# Outlier Detection Wasm Component + +This directory contains: + +- the interface definition for a Wasm component that can perform outlier detection on time series (see `wit/world.wit`) +- a Rust crate which exposes augurs' outlier detection functionality as a Wasm component (see `src/lib.rs`) +- tests for a Go package which can be generated from the Wasm component using [gravity] (see `outlier_test.go`) + +The goal is to provide a Go package that can embed some augurs functionality without having to resort to FFI. + +## Building + +### Prerequisites + +- A recent nightly Rust compiler (e.g. `rustup install nightly`) +- The `wasm32-unknown-unknown` target (`rustup target add wasm32-unknown-unknown`) +- [Go](https://go.dev/) +- [gravity](https://github.com/arcjet/gravity) +- [just](https://just.systems/man/en/) + +### Steps + +1. Build the Rust crate + + ```sh + just build-wasm + ``` + +2. Generate the Go package + + ```sh + just gen-go + ``` + +3. Run the Go tests + + + ```sh + just test-go + ``` diff --git a/examples/outlier/justfile b/examples/outlier/justfile new file mode 100644 index 0000000..961e786 --- /dev/null +++ b/examples/outlier/justfile @@ -0,0 +1,11 @@ +# Build the Rust code. +build-wasm: + RUSTFLAGS='--cfg getrandom_backend="custom"' cargo build --target wasm32-unknown-unknown --release + +# Generate the Go code from the Wasm file using Gravity. +gen-go: + go generate + +# Test the generated Go code. +test-go: + go test ./... diff --git a/examples/outlier/outlier_test.go b/examples/outlier/outlier_test.go new file mode 100644 index 0000000..df4cb9d --- /dev/null +++ b/examples/outlier/outlier_test.go @@ -0,0 +1,81 @@ +package outlier + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type outlierTypes struct{} + +func TestDetect(t *testing.T) { + t.Parallel() + + t.Run("dbscan", func(t *testing.T) { + ctx := t.Context() + f, err := NewOutlierFactory(ctx, outlierTypes{}) + require.NoError(t, err) + t.Cleanup(func() { f.Close(ctx) }) + + ins, err := f.Instantiate(ctx) + require.NoError(t, err) + t.Cleanup(func() { ins.Close(ctx) }) + + input := Input{ + Algorithm: AlgorithmDbscan{ + EpsilonOrSensitivity: EpsilonOrSensitivitySensitivity(0.5), + }, + Data: [][]float64{ + {1.0, 2.0, 1.5, 2.3}, + {1.9, 2.2, 1.2, 2.4}, + {1.5, 2.1, 6.4, 8.5}, + }, + } + + detected, err := ins.Detect(ctx, input) + require.NoError(t, err) + + assert.Len(t, detected.OutlyingSeries, 1) + assert.Contains(t, detected.OutlyingSeries, uint32(2)) + assert.True(t, detected.SeriesResults[2].IsOutlier) + assert.Equal(t, []float64{0.0, 0.0, 1.0, 1.0}, detected.SeriesResults[2].Scores) + assert.NotNil(t, detected.ClusterBand) + }) + + t.Run("mad", func(t *testing.T) { + t.Parallel() + ctx := t.Context() + f, err := NewOutlierFactory(ctx, outlierTypes{}) + require.NoError(t, err) + t.Cleanup(func() { f.Close(ctx) }) + + ins, err := f.Instantiate(ctx) + require.NoError(t, err) + t.Cleanup(func() { ins.Close(ctx) }) + + input := Input{ + Algorithm: AlgorithmMad{ + ThresholdOrSensitivity: ThresholdOrSensitivitySensitivity(0.5), + }, + Data: [][]float64{ + {31.6}, + {33.12}, + {33.84}, + {38.234}, // outlier + {12.83}, + {15.23}, + {33.23}, + {32.85}, + {24.72}, + }, + } + detected, err := ins.Detect(ctx, input) + require.NoError(t, err) + + assert.Len(t, detected.OutlyingSeries, 1) + assert.Contains(t, detected.OutlyingSeries, uint32(3)) + assert.True(t, detected.SeriesResults[3].IsOutlier) + assert.NotNil(t, detected.ClusterBand) + }) +} diff --git a/examples/outlier/src/lib.rs b/examples/outlier/src/lib.rs new file mode 100644 index 0000000..b497f5f --- /dev/null +++ b/examples/outlier/src/lib.rs @@ -0,0 +1,163 @@ +#![doc = include_str!("../README.md")] + +use std::fmt; + +use augurs_outlier::{DbscanDetector, MADDetector, OutlierDetector}; +use getrandom::Error; + +// It looks like some dependency or other is importing getrandom despite not actually +// using it, so we can just provide a dummy implementation here. +// If we see errors later we could try to import a RNG source via the Component Model. +// See https://docs.rs/getrandom/latest/getrandom/#custom-backend for info on this. +#[unsafe(no_mangle)] +unsafe extern "Rust" fn __getrandom_v03_custom(_dest: *mut u8, _len: usize) -> Result<(), Error> { + Err(Error::UNSUPPORTED) +} + +// Wrap the wit-bindgen macro in a module so we don't get warned about missing docs in the generated trait. +mod bindings { + wit_bindgen::generate!({ + world: "outlier", + default_bindings_module: "bindings", + }); +} +use bindings::{ + Guest, + augurs::outlier::types::{ + Algorithm, Band, EpsilonOrSensitivity, Input, OutlierInterval, Output, Series, + ThresholdOrSensitivity, + }, + export, +}; + +struct OutlierWorld; +export!(OutlierWorld); + +impl Guest for OutlierWorld { + fn detect(input: Input) -> Result { + detect(input).map_err(|e| e.to_string()) + } +} + +// #[derive(Debug, Deserialize)] +// #[serde(untagged, rename_all = "camelCase")] +// enum Algorithm { +// Dbscan(DbscanParams), +// Mad(MadParams), +// } + +// #[derive(Debug, Deserialize)] +// #[serde(rename_all = "camelCase")] +// struct DbscanParams { +// epsilon_or_sensitivity: EpsilonOrSensitivity, +// } + +// #[derive(Debug, Deserialize)] +// #[serde(rename_all = "camelCase")] +// enum EpsilonOrSensitivity { +// Sensitivity(f64), +// Epsilon(f64), +// } + +// #[derive(Debug, Deserialize)] +// #[serde(rename_all = "camelCase")] +// struct MadParams { +// threshold_or_sensitivity: ThresholdOrSensitivity, +// } + +// #[derive(Debug, Deserialize)] +// #[serde(rename_all = "camelCase")] +// enum ThresholdOrSensitivity { +// Sensitivity(f64), +// Threshold(f64), +// } + +// #[derive(Debug, Deserialize)] +// #[serde(rename_all = "lowercase")] +// struct Input { +// algorithm: Algorithm, +// data: Vec>, +// } + +#[derive(Debug)] +enum TypedError { + InvalidSensitivity(augurs_outlier::Error), + TryFromIntError(std::num::TryFromIntError), +} + +impl From for TypedError { + fn from(value: augurs_outlier::Error) -> Self { + Self::InvalidSensitivity(value) + } +} + +impl From for TypedError { + fn from(value: std::num::TryFromIntError) -> Self { + Self::TryFromIntError(value) + } +} + +impl fmt::Display for TypedError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::InvalidSensitivity(e) => write!(f, "invalid sensitivity: {}", e), + Self::TryFromIntError(e) => write!(f, "overflow converting to u32: {}", e), + } + } +} + +impl std::error::Error for TypedError {} + +fn detect(input: Input) -> Result { + let data_ref: Vec<_> = input.data.iter().map(|v| v.as_slice()).collect(); + let output = match input.algorithm { + Algorithm::Dbscan(params) => { + let detector = match params.epsilon_or_sensitivity { + EpsilonOrSensitivity::Sensitivity(s) => DbscanDetector::with_sensitivity(s)?, + EpsilonOrSensitivity::Epsilon(e) => DbscanDetector::with_epsilon(e), + }; + let preprocessed = detector.preprocess(&data_ref)?; + detector.detect(&preprocessed)? + } + Algorithm::Mad(params) => { + let detector = match params.threshold_or_sensitivity { + ThresholdOrSensitivity::Sensitivity(s) => MADDetector::with_sensitivity(s)?, + ThresholdOrSensitivity::Threshold(t) => MADDetector::with_threshold(t), + }; + let preprocessed = detector.preprocess(&data_ref)?; + detector.detect(&preprocessed)? + } + }; + Ok(Output { + outlying_series: output + .outlying_series + .into_iter() + .map(TryInto::::try_into) + .collect::>()?, + series_results: output + .series_results + .into_iter() + .map(|series| { + Ok(Series { + is_outlier: series.is_outlier, + outlier_intervals: series + .outlier_intervals + .intervals + .into_iter() + .map(|interval| { + Ok(OutlierInterval { + start: interval.start.try_into()?, + end: interval.end.map(TryInto::try_into).transpose()?, + }) + }) + .collect::>()?, + scores: series.scores, + }) + }) + .collect::>()?, + cluster_band: output.cluster_band.map(|band| Band { + min: band.min, + max: band.max, + }), + }) +} diff --git a/examples/outlier/wit/world.wit b/examples/outlier/wit/world.wit new file mode 100644 index 0000000..e8da5d5 --- /dev/null +++ b/examples/outlier/wit/world.wit @@ -0,0 +1,141 @@ +/// The outlier detector world. +package augurs:outlier; + +/// Types used by the outlier detector world. +interface types { + /// The input for the outlier detector. + /// + /// Currently this is represented as a string because Gravity + /// does not yet support more complex types. + /// + /// It should be a JSON object with the following fields: + record input { + /// The data to detect outliers in. + data: list>, + + /// The algorithm to use, with the params. + algorithm: algorithm, + } + + /// The output of an outlier detector + record output { + /// The indexes of the series considered outliers. + /// + /// This is a `BTreeSet` to ensure that the order of the series is preserved. + outlying-series: list, + + /// The results of the detection for each series. + series-results: list, + + /// The band indicating the min and max value considered outlying + /// at each timestamp. + /// + /// This may be `None` if no cluster was found (for example if + /// there were fewer than 3 series in the input data in the case of + /// DBSCAN). + cluster-band: option, + } + + /// A potentially outlying series. + record series { + /// Whether the series is an outlier for at least one of the samples. + is-outlier: bool, + /// The intervals of the samples that are considered outliers. + outlier-intervals: list, + /// The outlier scores of the series for each sample. + /// + /// The higher the score, the more likely the series is an outlier. + /// Note that some implementations may not provide continuous scores + /// but rather a binary classification. In this case, the scores will + /// be 0.0 for non-outliers and 1.0 for outliers. + scores: list, + } + + /// A single outlier interval. + /// + /// An outlier interval is a contiguous range of indices in a time series + /// where an outlier is detected. + record outlier-interval { + /// The start index of the interval. + start: u32, + /// The end index of the interval, if it exists. + /// + /// If the interval is open-ended, this will be `None`. + end: option, + } + + /// A band indicating the min and max value considered outlying + /// at each timestamp. + record band { + /// The minimum value considered outlying at each timestamp. + min: list, + /// The maximum value considered outlying at each timestamp. + max: list, + } + + /// Errors that can occur during outlier detection. + /// + /// Currently this is represented as a string because Gravity + /// does not yet support more complex types. + type error = string; + + /// The epsilon or sensitivity parameter for the DBSCAN algorithm. + variant epsilon-or-sensitivity { + /// A scale-invariant sensitivity parameter. + /// + /// This must be in (0, 1) and will be used to estimate a sensible + /// value of epsilon based on the data at detection-time. + sensitivity(f64), + /// The maximum distance between points in a cluster. + epsilon(f64), + } + + /// The parameters for the DBSCAN algorithm. + record dbscan-params { + /// Either the epsilon or sensitivity for the algorithm. + epsilon-or-sensitivity: epsilon-or-sensitivity, + } + + /// Either a scale-invariant sensitivity parameter or a threshold. + variant threshold-or-sensitivity { + /// A scale-invariant sensitivity parameter. + /// + /// This must be in (0, 1) and will be used to estimate a sensible + /// threshold at detection-time. + sensitivity(f64), + /// The threshold above which points are considered anomalous. + threshold(f64), + } + + /// The parameters for the MAD algorithm. + record mad-params { + /// Either the threshold or sensitivity for the algorithm. + threshold-or-sensitivity: threshold-or-sensitivity, + } + + /// The algorithm to use for outlier detection. + variant algorithm { + /// The DBSCAN algorithm. + /// + /// This algorithm is a density-based algorithm that uses a + /// clustering algorithm to group together points that are + /// close to each other. + dbscan(dbscan-params), + + /// The MAD algorithm. + /// + /// This algorithm is a density-based algorithm that uses a + /// clustering algorithm to group together points that are + /// close to each other. + mad(mad-params), + } +} + +/// The outlier detector world. +world outlier { + // export types; + use types.{input, output}; + + /// Detect outliers in the input. + export detect: func(input: input) -> result; +} diff --git a/examples/records/Cargo.toml b/examples/records/Cargo.toml new file mode 100644 index 0000000..1ae90a3 --- /dev/null +++ b/examples/records/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example-records" +version = "0.0.2" +edition = "2024" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wit-bindgen = "=0.46.0" +wit-component = "=0.239.0" diff --git a/examples/records/records_test.go b/examples/records/records_test.go new file mode 100644 index 0000000..83e6416 --- /dev/null +++ b/examples/records/records_test.go @@ -0,0 +1,66 @@ +package records + +import ( + "math" + "testing" +) + +type types struct{} + +func TestRecord(t *testing.T) { + tys := types{} + fac, err := NewRecordsFactory(t.Context(), tys) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + foo := Foo{ + Float32: 1.0, + Float64: 1.0, + Uint32: 1, + Uint64: uint64(math.MaxUint32), + S: "hello", + Vf32: []float32{1.0, 2.0, 3.0}, + Vf64: []float64{1.0, 2.0, 3.0}, + } + got := ins.ModifyFoo(t.Context(), foo) + want := Foo{ + Float32: foo.Float32 * 2.0, + Float64: foo.Float64 * 2.0, + Uint32: foo.Uint32 + 1, + Uint64: foo.Uint64 + 1, + S: "received hello", + Vf32: []float32{2.0, 4.0, 6.0}, + Vf64: []float64{2.0, 4.0, 6.0}, + } + if !fooCmp(got, want) { + t.Fatalf("got %+v, want %+v", got, want) + } +} + +func fooCmp(a, b Foo) bool { + if a.Float32 != b.Float32 || a.Float64 != b.Float64 || a.Uint32 != b.Uint32 || a.Uint64 != b.Uint64 || a.S != b.S { + return false + } + if len(a.Vf32) != len(b.Vf32) || len(a.Vf64) != len(b.Vf64) { + return false + } + for i := range a.Vf32 { + if a.Vf32[i] != b.Vf32[i] { + return false + } + } + for i := range a.Vf64 { + if a.Vf64[i] != b.Vf64[i] { + return false + } + } + return true +} diff --git a/examples/records/src/lib.rs b/examples/records/src/lib.rs new file mode 100644 index 0000000..f70d0f5 --- /dev/null +++ b/examples/records/src/lib.rs @@ -0,0 +1,31 @@ +wit_bindgen::generate!({ + world: "records", +}); + +struct RecordsWorld; + +export!(RecordsWorld); + +impl Guest for RecordsWorld { + fn modify_foo( + Foo { + float64, + float32, + uint32, + uint64, + s, + vf32, + vf64, + }: Foo, + ) -> Foo { + Foo { + float64: float64 * 2.0, + float32: float32 * 2.0, + uint32: uint32 + 1, + uint64: uint64 + 1, + s: format!("received {s}"), + vf32: vf32.iter().map(|v| v * 2.0).collect(), + vf64: vf64.iter().map(|v| v * 2.0).collect(), + } + } +} diff --git a/examples/records/wit/records.wit b/examples/records/wit/records.wit new file mode 100644 index 0000000..b467229 --- /dev/null +++ b/examples/records/wit/records.wit @@ -0,0 +1,18 @@ +package arcjet:records; + +interface types { + record foo { + float32: f32, + float64: f64, + uint32: u32, + uint64: u64, + s: string, + vf32: list, + vf64: list, + } +} + +world records { + use types.{foo}; + export modify-foo: func(f: foo) -> foo; +} diff --git a/examples/resources-simple/Cargo.toml b/examples/resources-simple/Cargo.toml new file mode 100644 index 0000000..b9154a7 --- /dev/null +++ b/examples/resources-simple/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example-resources-simple" +version = "0.0.2" +edition = "2024" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wit-bindgen = "=0.46.0" +wit-component = "=0.239.0" diff --git a/examples/resources-simple/resources.go b/examples/resources-simple/resources.go new file mode 100644 index 0000000..c5ce934 --- /dev/null +++ b/examples/resources-simple/resources.go @@ -0,0 +1,356 @@ +// Code generated by arcjet-gravity; DO NOT EDIT. + +package resources + +import "context" +import "errors" +import "github.com/tetratelabs/wazero" +import "github.com/tetratelabs/wazero/api" +import "sync" + +import _ "embed" + +//go:embed resources.wasm +var wasmFileResources []byte + +// ifaceFooerHandle is a handle to the fooer resource in the iface interface. +type ifaceFooerHandle uint32 + +type IfaceFooer interface { + GetX(ctx context.Context) uint32 + SetX(ctx context.Context, x uint32) + GetY(ctx context.Context) string + SetY(ctx context.Context, y string) +} + +type IResourcesIface[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]] interface { + NewFooer(ctx context.Context, x uint32, y string) TIfaceFooerValue +} + +// PIfaceFooer constrains a pointer to a type implementing the IfaceFooer interface. +type PIfaceFooer[TIfaceFooerValue any] interface { + *TIfaceFooerValue + IfaceFooer +} + +// ifaceFooerResourceTable is a resource table for fooer resources from the iface interface. +type ifaceFooerResourceTable[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]] struct { + mu sync.Mutex + nextHandle uint32 + table map[ifaceFooerHandle]*TIfaceFooerValue +} + +func newIfaceFooerResourceTable[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]]() *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer] { + return &ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]{ + nextHandle: 1, + table: make(map[ifaceFooerHandle]*TIfaceFooerValue), + } +} + +// Store adds a resource to the table and returns its handle. +func (t *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]) Store(resource TIfaceFooerValue) ifaceFooerHandle { + t.mu.Lock() + defer t.mu.Unlock() + handle := ifaceFooerHandle(t.nextHandle) + t.nextHandle++ + t.table[handle] = &resource + return handle +} + +// get returns a pointer to the resource from the table by its handle. +func (t *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]) get(handle ifaceFooerHandle) (PTIfaceFooer, bool) { + t.mu.Lock() + defer t.mu.Unlock() + resource, ok := t.table[handle] + if !ok { + var zero PTIfaceFooer + return zero, false + } + return resource, true +} + +// Get retrieves a resource from the table by its handle. +func (t *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]) Get(handle ifaceFooerHandle) (TIfaceFooerValue, bool) { + t.mu.Lock() + defer t.mu.Unlock() + resource, ok := t.table[handle] + if !ok { + var zero TIfaceFooerValue + return zero, false + } + return *resource, true +} + +// Remove deletes a resource from the table. +func (t *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]) Remove(handle ifaceFooerHandle) { + t.mu.Lock() + defer t.mu.Unlock() + delete(t.table, handle) +} + +type ResourcesFactory[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]] struct { + runtime wazero.Runtime + module wazero.CompiledModule + IfaceFooerResourceTable *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer] +} + +func NewResourcesFactory[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]]( + ctx context.Context, + iface IResourcesIface[TIfaceFooerValue, PTIfaceFooer], +) (*ResourcesFactory[TIfaceFooerValue, PTIfaceFooer], error) { + wazeroRuntime := wazero.NewRuntime(ctx) + // Initialize resource tables before host module instantiation + ifaceFooerResourceTable := newIfaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer]() + // Instantiate import host modules + _, err0 := wazeroRuntime.NewHostModuleBuilder("arcjet:resources/iface"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + arg2 uint32, + ) uint32{ + // GetArg { nth: 0 } + // U32FromI32 + result0 := api.DecodeU32(uint64(arg0)) + // GetArg { nth: 1 } + // GetArg { nth: 2 } + // StringLift + buf1, ok1 := mod.Memory().Read(arg1, arg2) + if !ok1 { + panic(errors.New("failed to read bytes from memory")) + } + str1 := string(buf1) + // CallInterface { func: Function { name: "[constructor]fooer", kind: Constructor(Id { idx: 0 }), params: [("x", U32), ("y", String)], result: Some(Id(Id { idx: 1 })), docs: Docs { contents: None }, stability: Unknown }, async_: false } + value2 := iface.NewFooer(ctx, result0, str1) + // HandleLower { handle: Own(Id { idx: 0 }), name: "fooer", ty: Id { idx: 1 } } + converted3 := uint32(ifaceFooerResourceTable.Store(value2)) + // Return { amt: 1, func: Function { name: "[constructor]fooer", kind: Constructor(Id { idx: 0 }), params: [("x", U32), ("y", String)], result: Some(Id(Id { idx: 1 })), docs: Docs { contents: None }, stability: Unknown } } + return converted3 + }). + Export("[constructor]fooer"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + ) uint32{ + // GetArg { nth: 0 } + // HandleLift { handle: Borrow(Id { idx: 0 }), name: "fooer", ty: Id { idx: 2 } } + converted0 := ifaceFooerHandle(arg0) + // CallInterface { func: Function { name: "[method]fooer.get-x", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 }))], result: Some(U32), docs: Docs { contents: None }, stability: Unknown }, async_: false } + resource1, ok1 := ifaceFooerResourceTable.get(converted0) + if !ok1 { + panic("invalid resource handle") + } + value1 := resource1.GetX(ctx) + // I32FromU32 + result2 := value1 + // Return { amt: 1, func: Function { name: "[method]fooer.get-x", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 }))], result: Some(U32), docs: Docs { contents: None }, stability: Unknown } } + return result2 + }). + Export("[method]fooer.get-x"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + ) { + // GetArg { nth: 0 } + // HandleLift { handle: Borrow(Id { idx: 0 }), name: "fooer", ty: Id { idx: 2 } } + converted0 := ifaceFooerHandle(arg0) + // GetArg { nth: 1 } + // U32FromI32 + result1 := api.DecodeU32(uint64(arg1)) + // CallInterface { func: Function { name: "[method]fooer.set-x", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 })), ("x", U32)], result: None, docs: Docs { contents: None }, stability: Unknown }, async_: false } + resource2, ok2 := ifaceFooerResourceTable.get(converted0) + if !ok2 { + panic("invalid resource handle") + } + resource2.SetX(ctx, result1) + // Return { amt: 0, func: Function { name: "[method]fooer.set-x", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 })), ("x", U32)], result: None, docs: Docs { contents: None }, stability: Unknown } } + }). + Export("[method]fooer.set-x"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + ) { + // GetArg { nth: 0 } + // HandleLift { handle: Borrow(Id { idx: 0 }), name: "fooer", ty: Id { idx: 2 } } + converted0 := ifaceFooerHandle(arg0) + // CallInterface { func: Function { name: "[method]fooer.get-y", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 }))], result: Some(String), docs: Docs { contents: None }, stability: Unknown }, async_: false } + resource1, ok1 := ifaceFooerResourceTable.get(converted0) + if !ok1 { + panic("invalid resource handle") + } + value1 := resource1.GetY(ctx) + // GetArg { nth: 1 } + // StringLower { realloc: Some("cabi_realloc") } + memory2 := mod.Memory() + realloc2 := mod.ExportedFunction("cabi_realloc") + ptr2, len2, err2 := writeString(ctx, value1, memory2, realloc2) + if err2 != nil { + panic(err2) + } + // LengthStore { offset: ptrsz } + mod.Memory().WriteUint32Le(uint32(arg1+4), uint32(len2)) + // PointerStore { offset: 0 } + mod.Memory().WriteUint32Le(uint32(arg1+0), uint32(ptr2)) + // Return { amt: 0, func: Function { name: "[method]fooer.get-y", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 }))], result: Some(String), docs: Docs { contents: None }, stability: Unknown } } + }). + Export("[method]fooer.get-y"). + NewFunctionBuilder(). + WithFunc(func( + ctx context.Context, + mod api.Module, + arg0 uint32, + arg1 uint32, + arg2 uint32, + ) { + // GetArg { nth: 0 } + // HandleLift { handle: Borrow(Id { idx: 0 }), name: "fooer", ty: Id { idx: 2 } } + converted0 := ifaceFooerHandle(arg0) + // GetArg { nth: 1 } + // GetArg { nth: 2 } + // StringLift + buf1, ok1 := mod.Memory().Read(arg1, arg2) + if !ok1 { + panic(errors.New("failed to read bytes from memory")) + } + str1 := string(buf1) + // CallInterface { func: Function { name: "[method]fooer.set-y", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 })), ("y", String)], result: None, docs: Docs { contents: None }, stability: Unknown }, async_: false } + resource2, ok2 := ifaceFooerResourceTable.get(converted0) + if !ok2 { + panic("invalid resource handle") + } + resource2.SetY(ctx, str1) + // Return { amt: 0, func: Function { name: "[method]fooer.set-y", kind: Method(Id { idx: 0 }), params: [("self", Id(Id { idx: 2 })), ("y", String)], result: None, docs: Docs { contents: None }, stability: Unknown } } + }). + Export("[method]fooer.set-y"). + NewFunctionBuilder(). + WithFunc(func(ctx context.Context, mod api.Module, arg0 uint32) { + // [resource-drop]: called when guest drops a resource + // + // With borrow-only parameters, guests never take ownership of host resources. + // Resources stay in host table until host explicitly removes them. + // This callback is a no-op since host controls the full lifecycle. + // + // Note: If we add owned parameter support in the future, this would need + // to implement ref-counting and state tracking to properly cleanup consumed resources. + _ = arg0 + }). + Export("[resource-drop]fooer"). + Instantiate(ctx) + if err0 != nil { + return nil, err0 + } + // Instantiate export resource management host modules + // Compiling the module takes a LONG time, so we want to do it once and hold + // onto it with the Runtime + module, err := wazeroRuntime.CompileModule(ctx, wasmFileResources) + if err != nil { + return nil, err + } + return &ResourcesFactory[TIfaceFooerValue, PTIfaceFooer]{ + runtime: wazeroRuntime, + module: module, + IfaceFooerResourceTable: ifaceFooerResourceTable, + }, nil +} + +func (f *ResourcesFactory[TIfaceFooerValue, PTIfaceFooer]) Instantiate(ctx context.Context) (*ResourcesInstance[TIfaceFooerValue, PTIfaceFooer], error) { + if module, err := f.runtime.InstantiateModule(ctx, f.module, wazero.NewModuleConfig()); err != nil { + return nil, err + } else { + return &ResourcesInstance[TIfaceFooerValue, PTIfaceFooer]{ + module: module, + IfaceFooerResourceTable: f.IfaceFooerResourceTable, + }, nil + } +} + +func (f *ResourcesFactory[TIfaceFooerValue, PTIfaceFooer]) Close(ctx context.Context) { + f.runtime.Close(ctx) +} + +type ResourcesInstance[TIfaceFooerValue any, PTIfaceFooer PIfaceFooer[TIfaceFooerValue]] struct { + module api.Module + IfaceFooerResourceTable *ifaceFooerResourceTable[TIfaceFooerValue, PTIfaceFooer] +} + +func (i *ResourcesInstance[TIfaceFooerValue, PTIfaceFooer]) Close(ctx context.Context) error { + if err := i.module.Close(ctx); err != nil { + return err + } + + return nil +} + +// writeString will put a Go string into the Wasm memory following the Component +// Model calling conventions, such as allocating memory with the realloc function +func writeString( + ctx context.Context, + s string, + memory api.Memory, + realloc api.Function, +) (uint64, uint64, error) { + if len(s) == 0 { + return 1, 0, nil + } + + results, err := realloc.Call(ctx, 0, 0, 1, uint64(len(s))) + if err != nil { + return 1, 0, err + } + ptr := results[0] + ok := memory.Write(uint32(ptr), []byte(s)) + if !ok { + return 1, 0, errors.New("failed to write string to wasm memory") + } + return uint64(ptr), uint64(len(s)), nil +} + +func (i *ResourcesInstance[TIfaceFooerValue, PTIfaceFooer]) UseFooer( + ctx context.Context, + foo ifaceFooerHandle, +) { + arg0 := foo + // GetArg { nth: 0 } + // HandleLower { handle: Borrow(Id { idx: 3 }), name: "fooer", ty: Id { idx: 4 } } + converted0 := uint32(arg0) + // CallWasm { name: "use-fooer", sig: WasmSignature { params: [I32], results: [], indirect_params: false, retptr: false } } + _, err1 := i.module.ExportedFunction("use-fooer").Call(ctx, uint64(converted0)) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + + // Return { amt: 0, func: Function { name: "use-fooer", kind: Freestanding, params: [("foo", Id(Id { idx: 4 }))], result: None, docs: Docs { contents: None }, stability: Unknown } } +} + +func (i *ResourcesInstance[TIfaceFooerValue, PTIfaceFooer]) UseFooerReturnNew( + ctx context.Context, + foo ifaceFooerHandle, +) ifaceFooerHandle { + arg0 := foo + // GetArg { nth: 0 } + // HandleLower { handle: Borrow(Id { idx: 3 }), name: "fooer", ty: Id { idx: 4 } } + converted0 := uint32(arg0) + // CallWasm { name: "use-fooer-return-new", sig: WasmSignature { params: [I32], results: [I32], indirect_params: false, retptr: false } } + raw1, err1 := i.module.ExportedFunction("use-fooer-return-new").Call(ctx, uint64(converted0)) + // The return type doesn't contain an error so we panic if one is encountered + if err1 != nil { + panic(err1) + } + + results1 := raw1[0] + // HandleLift { handle: Own(Id { idx: 3 }), name: "fooer", ty: Id { idx: 5 } } + converted2 := ifaceFooerHandle(results1) + // Return { amt: 1, func: Function { name: "use-fooer-return-new", kind: Freestanding, params: [("foo", Id(Id { idx: 4 }))], result: Some(Id(Id { idx: 5 })), docs: Docs { contents: None }, stability: Unknown } } + return converted2 +} diff --git a/examples/resources-simple/resources_test.go b/examples/resources-simple/resources_test.go new file mode 100644 index 0000000..ddc151c --- /dev/null +++ b/examples/resources-simple/resources_test.go @@ -0,0 +1,185 @@ +package resources + +import ( + "context" + "testing" +) + +// foo is a resource type. +type foo struct { + x uint32 + y string +} + +// GetX implements the IFaceFooer interface. +func (f *foo) GetX(context.Context) uint32 { + return f.x +} + +// GetY implements the IFaceFooer interface. +func (f *foo) GetY(context.Context) string { + return f.y +} + +// SetX implements the IFaceFooer interface. +func (f *foo) SetX(_ context.Context, x uint32) { + f.x = x +} + +// SetY implements the IFaceFooer interface. +func (f *foo) SetY(_ context.Context, y string) { + f.y = y +} + +// iface is an implementation of the IFaceFooer interface. +type iface struct{} + +// NewFoo implements the IResourcesIFace interface. +func (iface) NewFooer(ctx context.Context, x uint32, y string) foo { + return foo{x: x, y: y} +} + +func TestResources(t *testing.T) { + ctx := context.Background() + + // Create factory with interface implementations. + // Resource types are inferred from the interface implementation. + // TODO: check if type inference works when using multiple interfaces and/or resources. + fac, err := NewResourcesFactory(ctx, &iface{}) + if err != nil { + t.Fatal(err) + } + defer fac.Close(ctx) + + ins, err := fac.Instantiate(ctx) + if err != nil { + t.Fatal(err) + } + defer ins.Close(ctx) + + // Create a new instance of our `foo` resource, and store it + // in the table. + f1 := foo{x: 42, y: "Hello"} + handle := fac.IfaceFooerResourceTable.Store(f1) + defer fac.IfaceFooerResourceTable.Remove(handle) + + // Call the exported function `use-fooer` on the module, passing + // the handle to the `foo` resource. + ins.UseFooer(ctx, handle) + + // Get a copy of the resource from the table + // and check that it has the expected values. + // Since this is a borrowed resource, it should have been + // modified. + f1Ptr, ok := fac.IfaceFooerResourceTable.Get(handle) + if !ok { + t.Errorf("expected resource to be present in table") + } + if f1Ptr.x != 43 { + t.Errorf("expected f1Ptr.x to be 43, got %d", f1Ptr.x) + } + if f1Ptr.y != "world" { + t.Errorf("expected f1Ptr.y to be 'world', got '%s'", f1Ptr.y) + } + // Make sure the original wasn't modified. + if f1.x != 42 { + t.Errorf("expected f1.x to be unmodified, got %d", f1.x) + } + if f1.y != "Hello" { + t.Errorf("expected f1.y to be unmodified, got '%s'", f1.y) + } + +} + +func TestGuestCreatedResources(t *testing.T) { + ctx := context.Background() + + // Create factory with interface implementation + fac, err := NewResourcesFactory(ctx, &iface{}) + if err != nil { + t.Fatal(err) + } + defer fac.Close(ctx) + + ins, err := fac.Instantiate(ctx) + if err != nil { + t.Fatal(err) + } + defer ins.Close(ctx) + + t.Run("UseFooerReturnNew_creates_host_resource_via_callback", func(t *testing.T) { + // Create a host resource + hostFoo := foo{x: 100, y: "host"} + handle := fac.IfaceFooerResourceTable.Store(hostFoo) + defer fac.IfaceFooerResourceTable.Remove(handle) + + // Call guest function that borrows host resource and returns NEW resource + // The guest calls Fooer::new() which calls back to the HOST constructor + newHandle := ins.UseFooerReturnNew(ctx, handle) + + // The returned handle should be different (it's a new resource) + if newHandle == handle { + t.Errorf("expected different handle for new resource, got same: %d", newHandle) + } + + // The new resource IS in the host's table because the guest called the host constructor + newResource, ok := fac.IfaceFooerResourceTable.Get(newHandle) + if !ok { + t.Errorf("new resource should be in host table (created via host constructor)") + } + + // Verify it has the expected values (x+1, "world") + if newResource.x != 101 { + t.Errorf("expected new resource x=101, got %d", newResource.x) + } + if newResource.y != "world" { + t.Errorf("expected new resource y='world', got %q", newResource.y) + } + + // We can pass the new resource back to other guest functions + ins.UseFooer(ctx, newHandle) + + // The original host resource should still be in the table (we only borrowed it) + _, ok = fac.IfaceFooerResourceTable.Get(handle) + if !ok { + t.Errorf("original host resource should still be in table after borrow") + } + + // Clean up the new resource + fac.IfaceFooerResourceTable.Remove(newHandle) + }) + + t.Run("Multiple_new_resources_independent", func(t *testing.T) { + // Create multiple host resources and get back multiple new resources + host1 := foo{x: 1, y: "one"} + host2 := foo{x: 2, y: "two"} + handle1 := fac.IfaceFooerResourceTable.Store(host1) + defer fac.IfaceFooerResourceTable.Remove(handle1) + handle2 := fac.IfaceFooerResourceTable.Store(host2) + defer fac.IfaceFooerResourceTable.Remove(handle2) + + // Get new resources from both (guest calls host constructor) + new1 := ins.UseFooerReturnNew(ctx, handle1) + defer fac.IfaceFooerResourceTable.Remove(new1) + new2 := ins.UseFooerReturnNew(ctx, handle2) + defer fac.IfaceFooerResourceTable.Remove(new2) + + // New handles should be different from each other + if new1 == new2 { + t.Errorf("expected different handles, got same: %d", new1) + } + + // Both new resources should work independently + ins.UseFooer(ctx, new1) + ins.UseFooer(ctx, new2) + + // Can create another new resource from the same host resource + new3 := ins.UseFooerReturnNew(ctx, handle1) + defer fac.IfaceFooerResourceTable.Remove(new3) + if new3 == new1 { + t.Errorf("expected different handle for different resource, got same: %d", new3) + } + + ins.UseFooer(ctx, new3) + }) +} diff --git a/examples/resources-simple/src/lib.rs b/examples/resources-simple/src/lib.rs new file mode 100644 index 0000000..d0aeff7 --- /dev/null +++ b/examples/resources-simple/src/lib.rs @@ -0,0 +1,17 @@ +wit_bindgen::generate!(); + +struct ResourcesWorld; +export!(ResourcesWorld); + +impl Guest for ResourcesWorld { + fn use_fooer(foo: &Fooer) { + let x = foo.get_x(); + foo.set_x(x + 1); + foo.set_y("world"); + } + + fn use_fooer_return_new(foo: &Fooer) -> Fooer { + let x = foo.get_x(); + Fooer::new(x + 1, "world") + } +} diff --git a/examples/resources-simple/wit/world.wit b/examples/resources-simple/wit/world.wit new file mode 100644 index 0000000..89bbb23 --- /dev/null +++ b/examples/resources-simple/wit/world.wit @@ -0,0 +1,19 @@ +package arcjet:resources; + +interface iface { + resource fooer { + constructor(x: u32, y: string); + get-x: func() -> u32; + set-x: func(x: u32); + get-y: func() -> string; + set-y: func(y: string); + } +} + +world resources { + import iface; + use iface.{fooer}; + + export use-fooer: func(foo: borrow); + export use-fooer-return-new: func(foo: borrow) -> fooer; +} diff --git a/examples/resources/Cargo.toml b/examples/resources/Cargo.toml new file mode 100644 index 0000000..9b378e7 --- /dev/null +++ b/examples/resources/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example-resources" +version = "0.0.2" +edition = "2024" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wit-bindgen = "=0.46.0" +wit-component = "=0.239.0" diff --git a/examples/resources/resources_test.go b/examples/resources/resources_test.go new file mode 100644 index 0000000..ffe08e4 --- /dev/null +++ b/examples/resources/resources_test.go @@ -0,0 +1,354 @@ +package resources + +import ( + "context" + "testing" +) + +// ============================================================================ +// Host-side resource implementations for imported interfaces +// ============================================================================ + +// TypesAFoo implementation (host-provided) +type typesAFooResource struct { + x uint32 +} + +func (f *typesAFooResource) GetX(ctx context.Context) uint32 { + return f.x +} + +func (f *typesAFooResource) SetX(ctx context.Context, n uint32) { + f.x = n +} + +// TypesABar implementation (host-provided) +type typesABarResource struct { + value string +} + +func (f *typesABarResource) GetValue(ctx context.Context) string { + return f.value +} + +func (f *typesABarResource) Append(ctx context.Context, s string) { + f.value += s +} + +// TypesBFoo implementation (host-provided, different from types-a foo!) +type typesBFooResource struct { + y string +} + +func (f *typesBFooResource) GetY(ctx context.Context) string { + return f.y +} + +func (f *typesBFooResource) SetY(ctx context.Context, s string) { + f.y = s +} + +// TypesBBaz implementation (host-provided) +type typesBBazResource struct { + count uint32 +} + +func (f *typesBBazResource) Increment(ctx context.Context) { + f.count++ +} + +func (f *typesBBazResource) GetCount(ctx context.Context) uint32 { + return f.count +} + +// IResourcesTypesA implementation +type typesAImpl struct{} + +func (typesAImpl) NewFoo(ctx context.Context, x uint32) typesAFooResource { + return typesAFooResource{x: x} +} + +func (typesAImpl) NewBar(ctx context.Context, value string) typesABarResource { + return typesABarResource{value: value} +} + +func (typesAImpl) DoubleFooX(ctx context.Context, f *typesAFooResource) uint32 { + // This is called by guest WASM when it wants to use a host-provided resource + // The guest calls the freestanding function DoubleFooX with a borrowed foo handle + // which gets looked up from our resource table and passed here + return f.GetX(ctx) * 2 +} + +func (typesAImpl) MakeBar(ctx context.Context, value string) typesABarResource { + // This is called by guest WASM when it wants the host to create a bar resource + return typesABarResource{value: value} +} + +// IResourcesTypesB implementation +type typesBImpl struct{} + +func (typesBImpl) NewFoo(ctx context.Context, y string) typesBFooResource { + return typesBFooResource{y: y} +} + +func (typesBImpl) NewBaz(ctx context.Context, count uint32) typesBBazResource { + return typesBBazResource{count: count} +} + +func (typesBImpl) TripleBazCount(ctx context.Context, b *typesBBazResource) uint32 { + // This is called by guest WASM when it wants to use a host-provided resource + return b.GetCount(ctx) * 3 +} + +func (typesBImpl) MakeFoo(ctx context.Context, y string) typesBFooResource { + // This is called by guest WASM when it wants the host to create a foo resource + return typesBFooResource{y: y} +} + +// ============================================================================ +// Tests +// ============================================================================ + +func TestFreestandingFunctions(t *testing.T) { + ctx := context.Background() + + fac, err := NewResourcesFactory(ctx, &typesAImpl{}, &typesBImpl{}) + if err != nil { + t.Fatal(err) + } + defer fac.Close(ctx) + + ins, err := fac.Instantiate(ctx) + if err != nil { + t.Fatal(err) + } + defer ins.Close(ctx) + + // Note: We can't easily test DoubleFooX because it requires a guest-created foo resource, + // and we don't have a freestanding function in types-a that creates a foo. + // The constructor is guest-internal and shouldn't be called from the host. + + t.Run("MakeBar_creates_guest_resource", func(t *testing.T) { + // Call the guest's MakeBar function which creates a bar resource in the guest + // and returns an opaque handle + barHandle := ins.MakeBar(ctx, "hello") + + // Verify we got a non-zero handle (guest resources start from 1) + if barHandle == 0 { + t.Error("MakeBar returned zero handle, expected non-zero") + } + }) + + t.Run("MakeFoo_creates_guest_resource", func(t *testing.T) { + // Call the guest's MakeFoo function which creates a foo resource in the guest + // and returns an opaque handle + fooHandle := ins.MakeFoo(ctx, "test") + + // Verify we got a non-zero handle + if fooHandle == 0 { + t.Error("MakeFoo returned zero handle, expected non-zero") + } + }) + + t.Run("Multiple_guest_resources_independent", func(t *testing.T) { + // Create multiple guest bar resources + bar1 := ins.MakeBar(ctx, "first") + bar2 := ins.MakeBar(ctx, "second") + bar3 := ins.MakeBar(ctx, "third") + + // Verify different handles + if bar1 == bar2 || bar1 == bar3 || bar2 == bar3 { + t.Error("Expected different handles for different guest resources") + } + + // All handles should be non-zero + if bar1 == 0 || bar2 == 0 || bar3 == 0 { + t.Error("Expected non-zero handles from guest") + } + }) +} + +func TestHostResourceModification(t *testing.T) { + ctx := context.Background() + + fac, err := NewResourcesFactory(ctx, &typesAImpl{}, &typesBImpl{}) + if err != nil { + t.Fatal(err) + } + defer fac.Close(ctx) + + ins, err := fac.Instantiate(ctx) + if err != nil { + t.Fatal(err) + } + defer ins.Close(ctx) + + t.Run("Host_resource_state_preserved", func(t *testing.T) { + // Create a host resource + hostFoo := typesAFooResource{x: 100} + handle := fac.TypesAFooResourceTable.Store(hostFoo) + defer fac.TypesAFooResourceTable.Remove(handle) + + // Get it from the table and verify + retrieved, ok := fac.TypesAFooResourceTable.Get(handle) + if !ok { + t.Fatal("Failed to retrieve resource from table") + } + if retrieved.x != 100 { + t.Errorf("Retrieved resource has x=%d, want 100", retrieved.x) + } + + // Modify it + retrievedPtr, ok := fac.TypesAFooResourceTable.get(handle) + if !ok { + t.Fatal("Failed to get pointer to resource") + } + retrievedPtr.SetX(ctx, 200) + + // Verify modification + retrieved2, ok := fac.TypesAFooResourceTable.Get(handle) + if !ok { + t.Fatal("Failed to retrieve resource after modification") + } + if retrieved2.x != 200 { + t.Errorf("After SetX(200), x=%d, want 200", retrieved2.x) + } + }) +} + +func TestMultipleResources(t *testing.T) { + ctx := context.Background() + + fac, err := NewResourcesFactory(ctx, &typesAImpl{}, &typesBImpl{}) + if err != nil { + t.Fatal(err) + } + defer fac.Close(ctx) + + ins, err := fac.Instantiate(ctx) + if err != nil { + t.Fatal(err) + } + defer ins.Close(ctx) + + t.Run("Multiple_host_resources_independent", func(t *testing.T) { + // Create multiple host resources + foo1 := typesAFooResource{x: 10} + foo2 := typesAFooResource{x: 20} + foo3 := typesAFooResource{x: 30} + + handle1 := fac.TypesAFooResourceTable.Store(foo1) + defer fac.TypesAFooResourceTable.Remove(handle1) + handle2 := fac.TypesAFooResourceTable.Store(foo2) + defer fac.TypesAFooResourceTable.Remove(handle2) + handle3 := fac.TypesAFooResourceTable.Store(foo3) + defer fac.TypesAFooResourceTable.Remove(handle3) + + // Verify different handles + if handle1 == handle2 || handle1 == handle3 || handle2 == handle3 { + t.Error("Expected different handles for different resources") + } + + // Verify all were stored successfully + r1, ok1 := fac.TypesAFooResourceTable.Get(handle1) + r2, ok2 := fac.TypesAFooResourceTable.Get(handle2) + r3, ok3 := fac.TypesAFooResourceTable.Get(handle3) + + if !ok1 || !ok2 || !ok3 { + t.Fatal("Failed to retrieve one or more resources") + } + + // Verify independent state + if r1.x != 10 { + t.Errorf("resource1.x = %d, want 10", r1.x) + } + if r2.x != 20 { + t.Errorf("resource2.x = %d, want 20", r2.x) + } + if r3.x != 30 { + t.Errorf("resource3.x = %d, want 30", r3.x) + } + }) +} + +func TestResourceTableOperations(t *testing.T) { + ctx := context.Background() + + fac, err := NewResourcesFactory(ctx, &typesAImpl{}, &typesBImpl{}) + if err != nil { + t.Fatal(err) + } + defer fac.Close(ctx) + + t.Run("Store_and_Get", func(t *testing.T) { + foo := typesAFooResource{x: 42} + handle := fac.TypesAFooResourceTable.Store(foo) + + retrieved, ok := fac.TypesAFooResourceTable.Get(handle) + if !ok { + t.Fatal("Failed to retrieve stored resource") + } + if retrieved.x != 42 { + t.Errorf("Retrieved resource has x=%d, want 42", retrieved.x) + } + + fac.TypesAFooResourceTable.Remove(handle) + }) + + t.Run("Remove_makes_unavailable", func(t *testing.T) { + foo := typesAFooResource{x: 99} + handle := fac.TypesAFooResourceTable.Store(foo) + + fac.TypesAFooResourceTable.Remove(handle) + + _, ok := fac.TypesAFooResourceTable.Get(handle) + if ok { + t.Error("Expected resource to be unavailable after Remove") + } + }) + + t.Run("Multiple_tables_independent", func(t *testing.T) { + // Store in types-a foo table + fooA := typesAFooResource{x: 10} + handleA := fac.TypesAFooResourceTable.Store(fooA) + defer fac.TypesAFooResourceTable.Remove(handleA) + + // Store in types-a bar table + bar := typesABarResource{value: "test"} + handleBar := fac.TypesABarResourceTable.Store(bar) + defer fac.TypesABarResourceTable.Remove(handleBar) + + // Store in types-b foo table (different resource type!) + fooB := typesBFooResource{y: "hello"} + handleB := fac.TypesBFooResourceTable.Store(fooB) + defer fac.TypesBFooResourceTable.Remove(handleB) + + // Store in types-b baz table + baz := typesBBazResource{count: 5} + handleBaz := fac.TypesBBazResourceTable.Store(baz) + defer fac.TypesBBazResourceTable.Remove(handleBaz) + + // Verify all can be retrieved independently + retrievedA, okA := fac.TypesAFooResourceTable.Get(handleA) + retrievedBar, okBar := fac.TypesABarResourceTable.Get(handleBar) + retrievedB, okB := fac.TypesBFooResourceTable.Get(handleB) + retrievedBaz, okBaz := fac.TypesBBazResourceTable.Get(handleBaz) + + if !okA || !okBar || !okB || !okBaz { + t.Fatal("Failed to retrieve resources from different tables") + } + + if retrievedA.x != 10 { + t.Errorf("types-a foo: got x=%d, want 10", retrievedA.x) + } + if retrievedBar.value != "test" { + t.Errorf("types-a bar: got value=%q, want %q", retrievedBar.value, "test") + } + if retrievedB.y != "hello" { + t.Errorf("types-b foo: got y=%q, want %q", retrievedB.y, "hello") + } + if retrievedBaz.count != 5 { + t.Errorf("types-b baz: got count=%d, want 5", retrievedBaz.count) + } + }) +} diff --git a/examples/resources/src/lib.rs b/examples/resources/src/lib.rs new file mode 100644 index 0000000..bcdfa1b --- /dev/null +++ b/examples/resources/src/lib.rs @@ -0,0 +1,127 @@ +use std::cell::{Cell, RefCell}; + +use crate::exports::arcjet::resources::{ + types_a::{GuestBar as _, GuestFoo as _}, + types_b::{GuestBaz, GuestFoo as _}, +}; + +wit_bindgen::generate!({ + world: "resources", +}); + +struct ResourcesWorld; + +export!(ResourcesWorld); + +// Implementation for types-a::foo +struct FooA { + x: Cell, +} + +impl exports::arcjet::resources::types_a::GuestFoo for FooA { + fn new(x: u32) -> Self { + Self { x: Cell::new(x) } + } + + fn get_x(&self) -> u32 { + self.x.get() + } + + fn set_x(&self, n: u32) { + self.x.set(n); + } +} + +// Implementation for types-a::bar +struct Bar { + value: RefCell, +} + +impl exports::arcjet::resources::types_a::GuestBar for Bar { + fn new(value: String) -> Self { + Self { + value: RefCell::new(value), + } + } + + fn get_value(&self) -> String { + self.value.borrow().clone() + } + + fn append(&self, s: String) { + self.value.borrow_mut().push_str(&s); + } +} + +// Implementation for types-b::foo (different from types-a::foo!) +struct FooB { + y: RefCell, +} + +impl exports::arcjet::resources::types_b::GuestFoo for FooB { + fn new(y: String) -> Self { + Self { y: RefCell::new(y) } + } + + fn get_y(&self) -> String { + self.y.borrow().clone() + } + + fn set_y(&self, s: String) { + self.y.replace(s); + } +} + +// Implementation for types-b::baz +struct Baz { + count: Cell, +} + +impl exports::arcjet::resources::types_b::GuestBaz for Baz { + fn new(count: u32) -> Self { + Self { + count: Cell::new(count), + } + } + + fn increment(&self) { + self.count.update(|c| c + 1); + } + + fn get_count(&self) -> u32 { + self.count.get() + } +} + +// Guest implementations for both interfaces +impl exports::arcjet::resources::types_a::Guest for ResourcesWorld { + type Foo = FooA; + type Bar = Bar; + + // Function that takes a host-provided resource (import side) + fn double_foo_x(f: exports::arcjet::resources::types_a::FooBorrow<'_>) -> u32 { + // Call the host-provided foo's get_x method and double it + f.get::().get_x() * 2 + } + + // Function that creates and returns a guest resource (export side) + fn make_bar(value: String) -> exports::arcjet::resources::types_a::Bar { + exports::arcjet::resources::types_a::Bar::new(Bar::new(value)) + } +} + +impl exports::arcjet::resources::types_b::Guest for ResourcesWorld { + type Foo = FooB; + type Baz = Baz; + + // Function that takes a host-provided resource (import side) + fn triple_baz_count(b: exports::arcjet::resources::types_b::BazBorrow<'_>) -> u32 { + // Call the host-provided baz's get_count method and triple it + b.get::().get_count() * 3 + } + + // Function that creates and returns a guest resource (export side) + fn make_foo(y: String) -> exports::arcjet::resources::types_b::Foo { + exports::arcjet::resources::types_b::Foo::new(FooB::new(y)) + } +} diff --git a/examples/resources/wit/resources.wit b/examples/resources/wit/resources.wit new file mode 100644 index 0000000..d57a44b --- /dev/null +++ b/examples/resources/wit/resources.wit @@ -0,0 +1,54 @@ +package arcjet:resources; + +// First interface with foo and bar resources +interface types-a { + resource foo { + constructor(x: u32); + get-x: func() -> u32; + set-x: func(n: u32); + } + + resource bar { + constructor(value: string); + get-value: func() -> string; + append: func(s: string); + } + + // Function that takes a host-provided resource and uses it + // (import side: host provides foo, guest consumes it) + double-foo-x: func(f: borrow) -> u32; + + // Function that creates and returns a guest resource + // (export side: guest creates bar, host consumes it) + make-bar: func(value: string) -> bar; +} + +// Second interface with foo (name clash!) and baz resources +interface types-b { + resource foo { + constructor(y: string); + get-y: func() -> string; + set-y: func(s: string); + } + + resource baz { + constructor(count: u32); + increment: func(); + get-count: func() -> u32; + } + + // Function that takes a host-provided resource and uses it + // (import side: host provides baz, guest consumes it) + triple-baz-count: func(b: borrow) -> u32; + + // Function that creates and returns a guest resource + // (export side: guest creates foo, host consumes it) + make-foo: func(y: string) -> foo; +} + +world resources { + import types-a; + import types-b; + export types-a; + export types-b; +} diff --git a/examples/tuples/Cargo.toml b/examples/tuples/Cargo.toml new file mode 100644 index 0000000..4e22e96 --- /dev/null +++ b/examples/tuples/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example-tuples" +version = "0.0.2" +edition = "2024" + +[lib] +crate-type = ["cdylib"] + +[dependencies] +wit-bindgen = "=0.46.0" +wit-component = "=0.239.0" diff --git a/examples/tuples/basic_test.go b/examples/tuples/basic_test.go new file mode 100644 index 0000000..cbc4372 --- /dev/null +++ b/examples/tuples/basic_test.go @@ -0,0 +1,49 @@ +package tuples + +import ( + "testing" +) + +func TestCustomTupleFunc(t *testing.T) { + fac, err := NewTuplesFactory(t.Context(), struct{}{}) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + tuple := CustomTuple{F0: 0, F1: 1, F2: "2"} + actual := ins.CustomTupleFunc(t.Context(), tuple) + if actual != tuple { + t.Errorf("expected: %v, but got: %v", tuple, actual) + } +} + +func TestAnonymousTupleFunc(t *testing.T) { + fac, err := NewTuplesFactory(t.Context(), struct{}{}) + if err != nil { + t.Fatal(err) + } + defer fac.Close(t.Context()) + + ins, err := fac.Instantiate(t.Context()) + if err != nil { + t.Fatal(err) + } + defer ins.Close(t.Context()) + + tuple := struct { + F0 uint32 + F1 float64 + F2 string + }{F0: 0, F1: 1, F2: "2"} + actual := ins.AnonymousTupleFunc(t.Context(), tuple) + if actual != tuple { + t.Errorf("expected: %v, but got: %v", tuple, actual) + } +} diff --git a/examples/tuples/src/lib.rs b/examples/tuples/src/lib.rs new file mode 100644 index 0000000..2b039f2 --- /dev/null +++ b/examples/tuples/src/lib.rs @@ -0,0 +1,16 @@ +wit_bindgen::generate!({ + world: "tuples", +}); + +struct TuplesWorld; + +export!(TuplesWorld); + +impl Guest for TuplesWorld { + fn custom_tuple_func(t: (u32, f64, String)) -> (u32, f64, String) { + t + } + fn anonymous_tuple_func(t: (u32, f64, String)) -> (u32, f64, String) { + t + } +} diff --git a/examples/tuples/wit/tuple.wit b/examples/tuples/wit/tuple.wit new file mode 100644 index 0000000..c01e65d --- /dev/null +++ b/examples/tuples/wit/tuple.wit @@ -0,0 +1,13 @@ +package arcjet:tuples; + +interface types { + type custom-tuple = tuple; +} + +world tuples { + import types; + use types.{custom-tuple}; + + export custom-tuple-func: func(t: custom-tuple) -> custom-tuple; + export anonymous-tuple-func: func(t: tuple) -> tuple; +} diff --git a/go.mod b/go.mod index f7d103a..fc1db3b 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,13 @@ go 1.24.3 require github.com/tetratelabs/wazero v1.9.0 require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/stretchr/testify v1.11.1 // indirect golang.org/x/mod v0.24.0 // indirect golang.org/x/sync v0.14.0 // indirect golang.org/x/tools v0.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) tool golang.org/x/tools/cmd/goimports diff --git a/go.sum b/go.sum index 1ca1e1b..e5c4925 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,9 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/tetratelabs/wazero v1.9.0 h1:IcZ56OuxrtaEz8UYNRHBrUa9bYeX9oVY93KspZZBf/I= github.com/tetratelabs/wazero v1.9.0/go.mod h1:TSbcXCfFP0L2FGkRPxHphadXPjo1T6W+CseNNY7EkjM= golang.org/x/mod v0.24.0 h1:ZfthKaKaT4NrhGVZHO1/WDTwGES4De8KtWO0SIbNJMU= @@ -6,3 +12,6 @@ golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/tools v0.33.0 h1:4qz2S3zmRxbGIhDIAgjxvFutSvH5EfnsYrRBj0UI0bc= golang.org/x/tools v0.33.0/go.mod h1:CIJMaWEY88juyUfo7UbgPqbC8rU2OqfAV1h2Qp0oMYI= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=