Skip to content

Commit 959c74a

Browse files
committed
feat: added slice onnx import
1 parent 36ed65a commit 959c74a

File tree

12 files changed

+313
-3
lines changed

12 files changed

+313
-3
lines changed

crates/burn-import/SUPPORTED-ONNX-OPS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ represent the corresponding Burn Op.
171171
| [Sin][164] |||
172172
| [Sinh][165] |||
173173
| [Size][166] |||
174-
| [Slice][167] | ||
174+
| [Slice][167] | ||
175175
| [Softmax][168] |||
176176
| [SoftmaxCrossEntropyLoss][169] |||
177177
| [Softplus][170] |||

crates/burn-import/onnx-tests/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ fn main() {
6969
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
7070
.input("tests/pow/pow.onnx")
7171
.input("tests/pow/pow_int.onnx")
72+
.input("tests/slice/slice.onnx")
7273
.input("tests/sum/sum.onnx")
7374
.input("tests/sum/sum_int.onnx")
7475
.input("tests/unsqueeze/unsqueeze.onnx")

crates/burn-import/onnx-tests/tests/onnx_tests.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ include_models!(
7171
sigmoid,
7272
sign,
7373
sin,
74+
slice,
7475
softmax,
7576
sqrt,
7677
sub_int,
@@ -459,6 +460,19 @@ mod tests {
459460
assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2)));
460461
}
461462

463+
#[test]
464+
fn slice() {
465+
let model: slice::Model<Backend> = slice::Model::default();
466+
let device = Default::default();
467+
468+
let input =
469+
Tensor::<Backend, 1>::from_floats([1., 2., 3., 4., 5., 6., 7., 8., 9., 10.], &device);
470+
let output = model.forward(input);
471+
let expected = Data::from([1., 2., 3., 4., 5.]);
472+
473+
assert_eq!(output.to_data(), expected);
474+
}
475+
462476
#[test]
463477
fn softmax() {
464478
// Initialize the model without weights (because the exported file does not contain them)
Binary file not shown.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#!/usr/bin/env python3
2+
3+
# used to generate model: onnx-tests/tests/slice/slice.onnx
4+
5+
import onnx
6+
from onnx import helper, TensorProto
7+
8+
def main() -> None:
9+
# Starts
10+
starts_val = [0]  # start index for each sliced axis
11+
starts_tensor = helper.make_tensor(
12+
name="starts",
13+
data_type=TensorProto.INT64,
14+
dims=[len(starts_val)],
15+
vals=starts_val,
16+
)
17+
starts_node = helper.make_node(
18+
"Constant",
19+
name="starts_constant",
20+
inputs=[],
21+
outputs=["starts"],
22+
value=starts_tensor,
23+
)
24+
25+
# Ends
26+
ends_val = [5]  # end index (exclusive) for each sliced axis
27+
ends_tensor = helper.make_tensor(
28+
name="ends",
29+
data_type=TensorProto.INT64,
30+
dims=[len(ends_val)],
31+
vals=ends_val,
32+
)
33+
ends_node = helper.make_node(
34+
"Constant",
35+
name="ends_constant",
36+
inputs=[],
37+
outputs=["ends"],
38+
value=ends_tensor,
39+
)
40+
41+
# Axes
42+
axes_val = [0]  # axes along which to slice
43+
axes_tensor = helper.make_tensor(
44+
name="axes",
45+
data_type=TensorProto.INT64,
46+
dims=[len(axes_val)],
47+
vals=axes_val,
48+
)
49+
axes_node = helper.make_node(
50+
"Constant",
51+
name="axes_constant",
52+
inputs=[],
53+
outputs=["axes"],
54+
value=axes_tensor,
55+
)
56+
57+
# Steps
58+
steps_val = [1]  # step size for each sliced axis
59+
steps_tensor = helper.make_tensor(
60+
name="steps",
61+
data_type=TensorProto.INT64,
62+
dims=[len(steps_val)],
63+
vals=steps_val,
64+
)
65+
steps_node = helper.make_node(
66+
"Constant",
67+
name="steps_constant",
68+
inputs=[],
69+
outputs=["steps"],
70+
value=steps_tensor,
71+
)
72+
73+
# Define the Slice node that uses the outputs from the constant nodes
74+
slice_node = helper.make_node(
75+
"Slice",
76+
name="slice_node",
77+
inputs=["input_tensor", "starts", "ends", "axes", "steps"],
78+
outputs=["output"],
79+
)
80+
81+
# Create the graph
82+
graph_def = helper.make_graph(
83+
nodes=[starts_node, ends_node, axes_node, steps_node, slice_node],
84+
name="SliceGraph",
85+
inputs=[
86+
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [10]),
87+
],
88+
outputs=[
89+
helper.make_tensor_value_info("output", TensorProto.FLOAT, [5])
90+
],
91+
)
92+
93+
# Create the model
94+
model_def = helper.make_model(graph_def, producer_name="slice")
95+
96+
# Save the model to a file
97+
onnx.save(model_def, "slice.onnx")
98+
99+
100+
if __name__ == "__main__":
101+
main()

crates/burn-import/src/burn/node/base.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use super::{
77
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
88
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
99
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
10-
reshape::ReshapeNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
10+
reshape::ReshapeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
1111
unsqueeze::UnsqueezeNode,
1212
};
1313
use crate::burn::{BurnImports, Scope, Type};
@@ -102,6 +102,7 @@ pub enum Node<PS: PrecisionSettings> {
102102
MaxPool2d(MaxPool2dNode),
103103
Range(RangeNode),
104104
Reshape(ReshapeNode),
105+
Slice(SliceNode),
105106
Squeeze(SqueezeNode),
106107
Sum(SumNode),
107108
Unary(UnaryNode),
@@ -139,6 +140,7 @@ macro_rules! match_all {
139140
Node::MaxPool2d(node) => $func(node),
140141
Node::Range(node) => $func(node),
141142
Node::Reshape(node) => $func(node),
143+
Node::Slice(node) => $func(node),
142144
Node::Squeeze(node) => $func(node),
143145
Node::Sum(node) => $func(node),
144146
Node::Unary(node) => $func(node),
@@ -186,6 +188,7 @@ impl<PS: PrecisionSettings> Node<PS> {
186188
Node::MaxPool2d(_) => "max_pool2d",
187189
Node::Range(_) => "range",
188190
Node::Reshape(_) => "reshape",
191+
Node::Slice(_) => "slice",
189192
Node::Squeeze(_) => "squeeze",
190193
Node::Sum(_) => "add",
191194
Node::Unary(unary) => unary.kind.as_str(),

crates/burn-import/src/burn/node/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub(crate) mod random_normal;
2727
pub(crate) mod random_uniform;
2828
pub(crate) mod range;
2929
pub(crate) mod reshape;
30+
pub(crate) mod slice;
3031
pub(crate) mod squeeze;
3132
pub(crate) mod sum;
3233
pub(crate) mod unary;
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
use super::{Node, NodeCodegen};
2+
use crate::burn::{Scope, TensorType, Type};
3+
use burn::record::PrecisionSettings;
4+
use proc_macro2::TokenStream;
5+
use quote::quote;
6+
7+
#[derive(Debug, Clone, new)]
8+
pub struct SliceNode {
9+
pub input: TensorType,
10+
pub output: TensorType,
11+
pub starts: Vec<usize>,
12+
pub ends: Vec<usize>,
13+
// pub axes: Option<Vec<i64>>,
14+
// pub steps: Option<Vec<i64>>,
15+
}
16+
17+
impl<PS: PrecisionSettings> NodeCodegen<PS> for SliceNode {
18+
fn output_types(&self) -> Vec<Type> {
19+
vec![Type::Tensor(self.output.clone())]
20+
}
21+
fn input_types(&self) -> Vec<Type> {
22+
vec![Type::Tensor(self.input.clone())]
23+
}
24+
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
25+
let input = scope.tensor_use_owned(&self.input, node_position);
26+
let output = &self.output.name;
27+
// let axes = match &self.axes {
28+
// Some(axes) => axes.to_tokens(),
29+
// None => quote! { None },
30+
// };
31+
// let steps = match &self.steps {
32+
// Some(steps) => steps.to_tokens(),
33+
// None => quote! { None },
34+
// };
35+
36+
// combine starts and ends into ranges
37+
// example: starts = [0, 0, 0, 0], ends = [1, 1, 1, 1]
38+
// ranges = [0..1, 0..1, 0..1, 0..1]
39+
let starts = &self.starts;
40+
let ends = &self.ends;
41+
42+
quote! {
43+
let #output = #input.slice([#(#starts..#ends),*]);
44+
}
45+
}
46+
fn into_node(self) -> Node<PS> {
47+
Node::Slice(self)
48+
}
49+
}
50+
51+
#[cfg(test)]
52+
mod tests {
53+
use burn::record::FullPrecisionSettings;
54+
55+
use super::*;
56+
use crate::burn::{
57+
graph::BurnGraph,
58+
node::{slice::SliceNode, test::assert_tokens},
59+
TensorType,
60+
};
61+
62+
#[test]
63+
fn test_codegen_slice() {
64+
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
65+
graph.register(SliceNode::new(
66+
TensorType::new_float("tensor1", 4),
67+
TensorType::new_float("tensor2", 4),
68+
vec![0, 0, 0, 0],
69+
vec![1, 1, 1, 1],
70+
));
71+
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
72+
73+
let expected = quote! {
74+
use burn::{
75+
module::Module,
76+
tensor::{backend::Backend, Tensor},
77+
};
78+
79+
#[derive(Module, Debug)]
80+
pub struct Model<B: Backend> {
81+
phantom: core::marker::PhantomData<B>,
82+
device: burn::module::Ignored<B::Device>,
83+
}
84+
85+
impl<B: Backend> Model <B> {
86+
#[allow(unused_variables)]
87+
pub fn new(device: &B::Device) -> Self {
88+
Self {
89+
phantom: core::marker::PhantomData,
90+
device: burn::module::Ignored(device.clone()),
91+
}
92+
}
93+
#[allow(clippy::let_and_return, clippy::approx_constant)]
94+
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
95+
let tensor2 = tensor1.slice([0usize..1usize,0usize..1usize,0usize..1usize,0usize..1usize]);
96+
97+
tensor2
98+
}
99+
}
100+
};
101+
102+
assert_tokens(graph.codegen(), expected);
103+
}
104+
}

crates/burn-import/src/onnx/dim_inference.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
6464
NodeType::Sigmoid => same_as_input(node),
6565
NodeType::Sign => same_as_input(node),
6666
NodeType::Sin => same_as_input(node),
67+
NodeType::Slice => slice_update_outputs(node),
6768
NodeType::Softmax => same_as_input(node),
6869
NodeType::Sqrt => same_as_input(node),
6970
NodeType::Sub => same_as_input(node),
@@ -426,6 +427,33 @@ fn squeeze_update_output(node: &mut Node) {
426427
});
427428
}
428429

430+
fn slice_update_outputs(node: &mut Node) {
431+
let shape = match &node.inputs[1].value {
432+
Some(value) => match value {
433+
Data::Int64s(shape) => Some(shape.clone()),
434+
_ => panic!("Slice: invalid input types"),
435+
},
436+
None => None,
437+
};
438+
439+
if shape.is_none() {
440+
panic!("Slice: invalid shape");
441+
}
442+
443+
let output = match &node.outputs[0].ty {
444+
ArgType::Tensor(tensor) => tensor.clone(),
445+
_ => panic!("Slice: invalid output types"),
446+
};
447+
448+
if let Some(shape) = shape {
449+
node.outputs[0].ty = ArgType::Tensor(TensorType {
450+
dim: shape.len(),
451+
shape: None, // shape is calculated at runtime
452+
..output
453+
});
454+
}
455+
}
456+
429457
/// Update the output tensor dimension based on the "axes" attribute or the second input
430458
fn unsqueeze_update_output(node: &mut Node) {
431459
let axes = if node.inputs.len() == 2 {

crates/burn-import/src/onnx/from_onnx.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ use super::ir::{ArgType, Argument, Node, NodeType};
1717

1818
use protobuf::Message;
1919

20-
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
20+
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 11] = [
2121
NodeType::BatchNormalization,
2222
NodeType::Clip,
2323
NodeType::Conv1d,
@@ -27,6 +27,7 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
2727
NodeType::Reshape,
2828
NodeType::Unsqueeze,
2929
NodeType::ReduceSum,
30+
NodeType::Slice,
3031
NodeType::Squeeze,
3132
];
3233

0 commit comments

Comments
 (0)