@@ -15,8 +15,20 @@ world ml {
1515 import errors ;
1616}
1717
/// Inference is performed on a specific `device`.
interface device {
    /// Describes where tensor data resides and where graph execution takes
    /// place (i.e., the hardware target used for inference).
    enum location {
        cpu,
        gpu,
        tpu
    }
}
27+
1828/// All inputs and outputs to an ML inference are represented as `tensor` s.
1929interface tensor {
30+ use device . {location };
31+
2032 /// The dimensions of a tensor.
2133 ///
2234 /// The array length matches the tensor rank and each element in the array describes the size of
@@ -44,8 +56,8 @@ interface tensor {
4456 type tensor-data = list <u8 >;
4557
4658 resource tensor {
47- constructor ( dimensions : tensor-dimensions , ty : tensor-type , data : tensor-data ,
48- location : option < execution-target > );
59+ /// Construct a tensor that lives on the host CPU.
60+ constructor ( dimensions : tensor-dimensions , ty : tensor-type , data : tensor-data );
4961
5062 // Describe the size of the tensor (e.g., 2x2x2x2 -> [2, 2, 2, 2]). To represent a tensor
5163 // containing a single value, use `[1]` for the tensor dimensions.
@@ -55,7 +67,7 @@ interface tensor {
5567 ty : func () -> tensor-type ;
5668
5769 // Describe where the tensor is currently located (e.g., `cpu`, `gpu`, `tpu`).
58- location : func () -> execution-target ;
70+ location : func () -> location ;
5971
6072 // Return the tensor data. If the tensor is located on a device other than the CPU, this
6173 // operation may result in an expensive data copy operation.
@@ -74,8 +86,9 @@ interface tensor {
7486/// framework (e.g., TensorFlow):
7587interface graph {
7688 use errors . {error };
77- use tensor . { tensor };
89+ use device . { location };
7890 use inference . {graph-execution-context };
91+ use tensor . {tensor };
7992
8093 /// An execution graph for performing inference (i.e., a model).
8194 resource graph {
@@ -93,21 +106,15 @@ interface graph {
93106 autodetect ,
94107 }
95108
96- /// Define where the graph should be executed.
97- enum execution-target {
98- cpu ,
99- gpu ,
100- tpu
101- }
102-
103109 /// The graph initialization data.
104110 ///
105111 /// This gets bundled up into an array of buffers because implementing backends may encode their
106112 /// graph IR in parts (e.g., OpenVINO stores its IR and weights separately).
107113 type graph-builder = list <u8 >;
108114
109- /// Load a `graph` from an opaque sequence of bytes to use for inference.
110- load : func (builder : list <graph-builder >, encoding : graph-encoding , target : execution-target ) -> result <graph , error >;
115+ /// Load a `graph` from an opaque sequence of bytes to use for inference on the specified device
116+ /// `location` .
117+ load : func (builder : list <graph-builder >, encoding : graph-encoding , location : location ) -> result <graph , error >;
111118
112119 /// Load a `graph` by name.
113120 ///
@@ -128,6 +135,11 @@ interface inference {
128135 /// TODO: this may no longer be necessary in WIT
129136 /// (https://github.com/WebAssembly/wasi-nn/issues/43)
130137 resource graph-execution-context {
138+ /// Load a tensor using the graph context. Unlike the `tensor` constructor, this function
139+ /// will co-locate the tensor data on a specific device using the graph's underlying
140+ /// backend; this may avoid some copies, improving performance.
141+ load-tensor : func (dimensions : tensor-dimensions , ty : tensor-type , data : tensor-data ) -> result <tensor , error >;
142+
131143 /// Define the inputs to use for inference.
132144 set-input : func (name : string , tensor : tensor ) -> result <_ , error >;
133145
0 commit comments