1717
1818use crate :: backend:: Id ;
1919use crate :: { Backend , Registry } ;
20+ use anyhow:: anyhow;
2021use std:: collections:: HashMap ;
2122use std:: hash:: Hash ;
2223use std:: { fmt, str:: FromStr } ;
@@ -54,29 +55,63 @@ impl<'a> WasiNnView<'a> {
5455 }
5556}
5657
57- pub enum Error {
58+ /// A wasi-nn error; this appears on the Wasm side as a component model
59+ /// resource.
60+ #[ derive( Debug ) ]
61+ pub struct Error {
62+ code : ErrorCode ,
63+ data : anyhow:: Error ,
64+ }
65+
66+ /// Construct an [`Error`] resource and immediately return it.
67+ ///
68+ /// The WIT specification currently relies on "errors as resources;" this helper
69+ /// macro hides some of that complexity. If [#75] is adopted ("errors as
70+ /// records"), this macro is no longer necessary.
71+ ///
72+ /// [#75]: https://github.com/WebAssembly/wasi-nn/pull/75
73+ macro_rules! bail {
74+ ( $self: ident, $code: expr, $data: expr) => {
75+ let e = Error {
76+ code: $code,
77+ data: $data. into( ) ,
78+ } ;
79+ tracing:: error!( "failure: {e:?}" ) ;
80+ let r = $self. table. push( e) ?;
81+ return Ok ( Err ( r) ) ;
82+ } ;
83+ }
84+
85+ impl From < wasmtime:: component:: ResourceTableError > for Error {
86+ fn from ( error : wasmtime:: component:: ResourceTableError ) -> Self {
87+ Self {
88+ code : ErrorCode :: Trap ,
89+ data : error. into ( ) ,
90+ }
91+ }
92+ }
93+
94+ /// The list of error codes available to the `wasi-nn` API; this should match
95+ /// what is specified in WIT.
96+ #[ derive( Debug ) ]
97+ pub enum ErrorCode {
5898 /// Caller module passed an invalid argument.
5999 InvalidArgument ,
60100 /// Invalid encoding.
61101 InvalidEncoding ,
62102 /// The operation timed out.
63103 Timeout ,
64- /// Runtime Error .
104+ /// Runtime error .
65105 RuntimeError ,
66106 /// Unsupported operation.
67107 UnsupportedOperation ,
68108 /// Graph is too large.
69109 TooLarge ,
70110 /// Graph not found.
71111 NotFound ,
72- /// A runtime error occurred that we should trap on; see `StreamError`.
73- Trap ( anyhow:: Error ) ,
74- }
75-
76- impl From < wasmtime:: component:: ResourceTableError > for Error {
77- fn from ( error : wasmtime:: component:: ResourceTableError ) -> Self {
78- Self :: Trap ( error. into ( ) )
79- }
112+ /// A runtime error that Wasmtime should trap on; this will not appear in
113+ /// the WIT specification.
114+ Trap ,
80115}
81116
82117/// Generate the traits and types from the `wasi-nn` WIT specification.
@@ -91,6 +126,7 @@ mod gen_ {
91126 "wasi:nn/graph/graph" : crate :: Graph ,
92127 "wasi:nn/tensor/tensor" : crate :: Tensor ,
93128 "wasi:nn/inference/graph-execution-context" : crate :: ExecutionContext ,
129+ "wasi:nn/errors/error" : super :: Error ,
94130 } ,
95131 trappable_error_type: {
96132 "wasi:nn/errors/error" => super :: Error ,
@@ -131,36 +167,45 @@ impl gen::graph::Host for WasiNnView<'_> {
131167 builders : Vec < GraphBuilder > ,
132168 encoding : GraphEncoding ,
133169 target : ExecutionTarget ,
134- ) -> Result < Resource < crate :: Graph > , Error > {
170+ ) -> wasmtime :: Result < Result < Resource < Graph > , Resource < Error > > > {
135171 tracing:: debug!( "load {encoding:?} {target:?}" ) ;
136172 if let Some ( backend) = self . ctx . backends . get_mut ( & encoding) {
137173 let slices = builders. iter ( ) . map ( |s| s. as_slice ( ) ) . collect :: < Vec < _ > > ( ) ;
138174 match backend. load ( & slices, target. into ( ) ) {
139175 Ok ( graph) => {
140176 let graph = self . table . push ( graph) ?;
141- Ok ( graph)
177+ Ok ( Ok ( graph) )
142178 }
143179 Err ( error) => {
144- tracing:: error!( "failed to load graph: {error:?}" ) ;
145- Err ( Error :: RuntimeError )
180+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
146181 }
147182 }
148183 } else {
149- Err ( Error :: InvalidEncoding )
184+ bail ! (
185+ self ,
186+ ErrorCode :: InvalidEncoding ,
187+ anyhow!( "unable to find a backend for this encoding" )
188+ ) ;
150189 }
151190 }
152191
153- fn load_by_name ( & mut self , name : String ) -> Result < Resource < Graph > , Error > {
192+ fn load_by_name (
193+ & mut self ,
194+ name : String ,
195+ ) -> wasmtime:: Result < Result < Resource < Graph > , Resource < Error > > > {
154196 use core:: result:: Result :: * ;
155197 tracing:: debug!( "load by name {name:?}" ) ;
156198 let registry = & self . ctx . registry ;
157199 if let Some ( graph) = registry. get ( & name) {
158200 let graph = graph. clone ( ) ;
159201 let graph = self . table . push ( graph) ?;
160- Ok ( graph)
202+ Ok ( Ok ( graph) )
161203 } else {
162- tracing:: error!( "failed to find graph with name: {name}" ) ;
163- Err ( Error :: NotFound )
204+ bail ! (
205+ self ,
206+ ErrorCode :: NotFound ,
207+ anyhow!( "failed to find graph with name: {name}" )
208+ ) ;
164209 }
165210 }
166211}
@@ -169,18 +214,17 @@ impl gen::graph::HostGraph for WasiNnView<'_> {
169214 fn init_execution_context (
170215 & mut self ,
171216 graph : Resource < Graph > ,
172- ) -> Result < Resource < GraphExecutionContext > , Error > {
217+ ) -> wasmtime :: Result < Result < Resource < GraphExecutionContext > , Resource < Error > > > {
173218 use core:: result:: Result :: * ;
174219 tracing:: debug!( "initialize execution context" ) ;
175220 let graph = self . table . get ( & graph) ?;
176221 match graph. init_execution_context ( ) {
177222 Ok ( exec_context) => {
178223 let exec_context = self . table . push ( exec_context) ?;
179- Ok ( exec_context)
224+ Ok ( Ok ( exec_context) )
180225 }
181226 Err ( error) => {
182- tracing:: error!( "failed to initialize execution context: {error:?}" ) ;
183- Err ( Error :: RuntimeError )
227+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
184228 }
185229 }
186230 }
@@ -197,47 +241,46 @@ impl gen::inference::HostGraphExecutionContext for WasiNnView<'_> {
197241 exec_context : Resource < GraphExecutionContext > ,
198242 name : String ,
199243 tensor : Resource < Tensor > ,
200- ) -> Result < ( ) , Error > {
244+ ) -> wasmtime :: Result < Result < ( ) , Resource < Error > > > {
201245 let tensor = self . table . get ( & tensor) ?;
202246 tracing:: debug!( "set input {name:?}: {tensor:?}" ) ;
203247 let tensor = tensor. clone ( ) ; // TODO: avoid copying the tensor
204248 let exec_context = self . table . get_mut ( & exec_context) ?;
205- if let Err ( e) = exec_context. set_input ( Id :: Name ( name) , & tensor) {
206- tracing:: error!( "failed to set input: {e:?}" ) ;
207- Err ( Error :: InvalidArgument )
249+ if let Err ( error) = exec_context. set_input ( Id :: Name ( name) , & tensor) {
250+ bail ! ( self , ErrorCode :: InvalidArgument , error) ;
208251 } else {
209- Ok ( ( ) )
252+ Ok ( Ok ( ( ) ) )
210253 }
211254 }
212255
213- fn compute ( & mut self , exec_context : Resource < GraphExecutionContext > ) -> Result < ( ) , Error > {
256+ fn compute (
257+ & mut self ,
258+ exec_context : Resource < GraphExecutionContext > ,
259+ ) -> wasmtime:: Result < Result < ( ) , Resource < Error > > > {
214260 let exec_context = & mut self . table . get_mut ( & exec_context) ?;
215261 tracing:: debug!( "compute" ) ;
216262 match exec_context. compute ( ) {
217- Ok ( ( ) ) => Ok ( ( ) ) ,
263+ Ok ( ( ) ) => Ok ( Ok ( ( ) ) ) ,
218264 Err ( error) => {
219- tracing:: error!( "failed to compute: {error:?}" ) ;
220- Err ( Error :: RuntimeError )
265+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
221266 }
222267 }
223268 }
224269
225- #[ doc = r" Extract the outputs after inference." ]
226270 fn get_output (
227271 & mut self ,
228272 exec_context : Resource < GraphExecutionContext > ,
229273 name : String ,
230- ) -> Result < Resource < Tensor > , Error > {
274+ ) -> wasmtime :: Result < Result < Resource < Tensor > , Resource < Error > > > {
231275 let exec_context = self . table . get_mut ( & exec_context) ?;
232276 tracing:: debug!( "get output {name:?}" ) ;
233277 match exec_context. get_output ( Id :: Name ( name) ) {
234278 Ok ( tensor) => {
235279 let tensor = self . table . push ( tensor) ?;
236- Ok ( tensor)
280+ Ok ( Ok ( tensor) )
237281 }
238282 Err ( error) => {
239- tracing:: error!( "failed to get output: {error:?}" ) ;
240- Err ( Error :: RuntimeError )
283+ bail ! ( self , ErrorCode :: RuntimeError , error) ;
241284 }
242285 }
243286 }
@@ -285,21 +328,51 @@ impl gen::tensor::HostTensor for WasiNnView<'_> {
285328 }
286329}
287330
288- impl gen:: tensor:: Host for WasiNnView < ' _ > { }
331+ impl gen:: errors:: HostError for WasiNnView < ' _ > {
332+ fn new (
333+ & mut self ,
334+ _code : gen:: errors:: ErrorCode ,
335+ _data : String ,
336+ ) -> wasmtime:: Result < Resource < Error > > {
337+ unimplemented ! ( "this should be removed; see https://github.com/WebAssembly/wasi-nn/pull/76" )
338+ }
339+
340+ fn code ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < gen:: errors:: ErrorCode > {
341+ let error = self . table . get ( & error) ?;
342+ match error. code {
343+ ErrorCode :: InvalidArgument => Ok ( gen:: errors:: ErrorCode :: InvalidArgument ) ,
344+ ErrorCode :: InvalidEncoding => Ok ( gen:: errors:: ErrorCode :: InvalidEncoding ) ,
345+ ErrorCode :: Timeout => Ok ( gen:: errors:: ErrorCode :: Timeout ) ,
346+ ErrorCode :: RuntimeError => Ok ( gen:: errors:: ErrorCode :: RuntimeError ) ,
347+ ErrorCode :: UnsupportedOperation => Ok ( gen:: errors:: ErrorCode :: UnsupportedOperation ) ,
348+ ErrorCode :: TooLarge => Ok ( gen:: errors:: ErrorCode :: TooLarge ) ,
349+ ErrorCode :: NotFound => Ok ( gen:: errors:: ErrorCode :: NotFound ) ,
350+ ErrorCode :: Trap => Err ( anyhow ! ( error. data. to_string( ) ) ) ,
351+ }
352+ }
353+
354+ fn data ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < String > {
355+ let error = self . table . get ( & error) ?;
356+ Ok ( error. data . to_string ( ) )
357+ }
358+
359+ fn drop ( & mut self , error : Resource < Error > ) -> wasmtime:: Result < ( ) > {
360+ self . table . delete ( error) ?;
361+ Ok ( ( ) )
362+ }
363+ }
364+
289365impl gen:: errors:: Host for WasiNnView < ' _ > {
290- fn convert_error ( & mut self , err : Error ) -> wasmtime:: Result < gen:: errors:: Error > {
291- match err {
292- Error :: InvalidArgument => Ok ( gen:: errors:: Error :: InvalidArgument ) ,
293- Error :: InvalidEncoding => Ok ( gen:: errors:: Error :: InvalidEncoding ) ,
294- Error :: Timeout => Ok ( gen:: errors:: Error :: Timeout ) ,
295- Error :: RuntimeError => Ok ( gen:: errors:: Error :: RuntimeError ) ,
296- Error :: UnsupportedOperation => Ok ( gen:: errors:: Error :: UnsupportedOperation ) ,
297- Error :: TooLarge => Ok ( gen:: errors:: Error :: TooLarge ) ,
298- Error :: NotFound => Ok ( gen:: errors:: Error :: NotFound ) ,
299- Error :: Trap ( e) => Err ( e) ,
366+ fn convert_error ( & mut self , err : Error ) -> wasmtime:: Result < Error > {
367+ if matches ! ( err. code, ErrorCode :: Trap ) {
368+ Err ( err. data )
369+ } else {
370+ Ok ( err)
300371 }
301372 }
302373}
374+
375+ impl gen:: tensor:: Host for WasiNnView < ' _ > { }
303376impl gen:: inference:: Host for WasiNnView < ' _ > { }
304377
305378impl Hash for gen:: graph:: GraphEncoding {
0 commit comments