Skip to content

Commit 671ec8c

Browse files
authored
feat: added slice onnx import (#1856)
* feat: added slice onnx import * fix: axes, steps handling
1 parent dd60446 commit 671ec8c

File tree

12 files changed

+318
-3
lines changed

12 files changed

+318
-3
lines changed

crates/burn-import/SUPPORTED-ONNX-OPS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ represent the corresponding Burn Op.
171171
| [Sin][164] |||
172172
| [Sinh][165] |||
173173
| [Size][166] |||
174-
| [Slice][167] | ||
174+
| [Slice][167] | ||
175175
| [Softmax][168] |||
176176
| [SoftmaxCrossEntropyLoss][169] |||
177177
| [Softplus][170] |||

crates/burn-import/onnx-tests/build.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ fn main() {
6969
.input("tests/conv_transpose2d/conv_transpose2d.onnx")
7070
.input("tests/pow/pow.onnx")
7171
.input("tests/pow/pow_int.onnx")
72+
.input("tests/slice/slice.onnx")
7273
.input("tests/sum/sum.onnx")
7374
.input("tests/sum/sum_int.onnx")
7475
.input("tests/unsqueeze/unsqueeze.onnx")

crates/burn-import/onnx-tests/tests/onnx_tests.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ include_models!(
7171
sigmoid,
7272
sign,
7373
sin,
74+
slice,
7475
softmax,
7576
sqrt,
7677
sub_int,
@@ -459,6 +460,24 @@ mod tests {
459460
assert!(expected_sum_2d.approx_eq(output_sum_2d, (1.0e-4, 2)));
460461
}
461462

463+
#[test]
464+
fn slice() {
465+
let model: slice::Model<Backend> = slice::Model::default();
466+
let device = Default::default();
467+
468+
let input = Tensor::<Backend, 2>::from_floats(
469+
[
470+
[1., 2., 3., 4., 5., 6., 7., 8., 9., 10.],
471+
[11., 12., 13., 14., 15., 16., 17., 18., 19., 20.],
472+
],
473+
&device,
474+
);
475+
let output = model.forward(input);
476+
let expected = Data::from([[1., 2., 3., 4., 5.]]);
477+
478+
assert_eq!(output.to_data(), expected);
479+
}
480+
462481
#[test]
463482
fn softmax() {
464483
// Initialize the model without weights (because the exported file does not contain them)
Binary file not shown.
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
#!/usr/bin/env python3
2+
3+
# used to generate model: onnx-tests/tests/slice/slice.onnx
4+
5+
import onnx
6+
from onnx import helper, TensorProto
7+
8+
def main() -> None:
9+
# Starts
10+
starts_val = [0,0] # Example shape value
11+
starts_tensor = helper.make_tensor(
12+
name="starts",
13+
data_type=TensorProto.INT64,
14+
dims=[len(starts_val)],
15+
vals=starts_val,
16+
)
17+
starts_node = helper.make_node(
18+
"Constant",
19+
name="starts_constant",
20+
inputs=[],
21+
outputs=["starts"],
22+
value=starts_tensor,
23+
)
24+
25+
# Ends
26+
ends_val = [1,5] # Example shape value
27+
ends_tensor = helper.make_tensor(
28+
name="ends",
29+
data_type=TensorProto.INT64,
30+
dims=[len(ends_val)],
31+
vals=ends_val,
32+
)
33+
ends_node = helper.make_node(
34+
"Constant",
35+
name="ends_constant",
36+
inputs=[],
37+
outputs=["ends"],
38+
value=ends_tensor,
39+
)
40+
41+
# Axes
42+
axes_val = [0,1] # Example shape value
43+
axes_tensor = helper.make_tensor(
44+
name="axes",
45+
data_type=TensorProto.INT64,
46+
dims=[len(axes_val)],
47+
vals=axes_val,
48+
)
49+
axes_node = helper.make_node(
50+
"Constant",
51+
name="axes_constant",
52+
inputs=[],
53+
outputs=["axes"],
54+
value=axes_tensor,
55+
)
56+
57+
# Steps
58+
steps_val = [1, 1] # Example shape value
59+
steps_tensor = helper.make_tensor(
60+
name="steps",
61+
data_type=TensorProto.INT64,
62+
dims=[len(steps_val)],
63+
vals=steps_val,
64+
)
65+
steps_node = helper.make_node(
66+
"Constant",
67+
name="steps_constant",
68+
inputs=[],
69+
outputs=["steps"],
70+
value=steps_tensor,
71+
)
72+
73+
# Define the Slice node that uses the outputs from the constant nodes
74+
slice_node = helper.make_node(
75+
"Slice",
76+
name="slice_node",
77+
inputs=["input_tensor", "starts", "ends", "axes", "steps"],
78+
outputs=["output"],
79+
)
80+
81+
# Create the graph
82+
graph_def = helper.make_graph(
83+
nodes=[starts_node, ends_node, axes_node, steps_node, slice_node],
84+
name="SliceGraph",
85+
inputs=[
86+
helper.make_tensor_value_info("input_tensor", TensorProto.FLOAT, [2, 10]),
87+
],
88+
outputs=[
89+
helper.make_tensor_value_info("output", TensorProto.FLOAT, [1, 5])
90+
],
91+
)
92+
93+
# Create the model
94+
model_def = helper.make_model(graph_def, producer_name="slice")
95+
96+
# Save the model to a file
97+
onnx.save(model_def, "slice.onnx")
98+
99+
100+
if __name__ == "__main__":
101+
main()

crates/burn-import/src/burn/node/base.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use super::{
77
layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
88
max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode,
99
random_normal::RandomNormalNode, random_uniform::RandomUniformNode, range::RangeNode,
10-
reshape::ReshapeNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
10+
reshape::ReshapeNode, slice::SliceNode, squeeze::SqueezeNode, sum::SumNode, unary::UnaryNode,
1111
unsqueeze::UnsqueezeNode,
1212
};
1313
use crate::burn::{BurnImports, Scope, Type};
@@ -102,6 +102,7 @@ pub enum Node<PS: PrecisionSettings> {
102102
MaxPool2d(MaxPool2dNode),
103103
Range(RangeNode),
104104
Reshape(ReshapeNode),
105+
Slice(SliceNode),
105106
Squeeze(SqueezeNode),
106107
Sum(SumNode),
107108
Unary(UnaryNode),
@@ -139,6 +140,7 @@ macro_rules! match_all {
139140
Node::MaxPool2d(node) => $func(node),
140141
Node::Range(node) => $func(node),
141142
Node::Reshape(node) => $func(node),
143+
Node::Slice(node) => $func(node),
142144
Node::Squeeze(node) => $func(node),
143145
Node::Sum(node) => $func(node),
144146
Node::Unary(node) => $func(node),
@@ -186,6 +188,7 @@ impl<PS: PrecisionSettings> Node<PS> {
186188
Node::MaxPool2d(_) => "max_pool2d",
187189
Node::Range(_) => "range",
188190
Node::Reshape(_) => "reshape",
191+
Node::Slice(_) => "slice",
189192
Node::Squeeze(_) => "squeeze",
190193
Node::Sum(_) => "add",
191194
Node::Unary(unary) => unary.kind.as_str(),

crates/burn-import/src/burn/node/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ pub(crate) mod random_normal;
2727
pub(crate) mod random_uniform;
2828
pub(crate) mod range;
2929
pub(crate) mod reshape;
30+
pub(crate) mod slice;
3031
pub(crate) mod squeeze;
3132
pub(crate) mod sum;
3233
pub(crate) mod unary;
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
use super::{Node, NodeCodegen};
2+
use crate::burn::{Scope, TensorType, Type};
3+
use burn::record::PrecisionSettings;
4+
use proc_macro2::TokenStream;
5+
use quote::quote;
6+
7+
#[derive(Debug, Clone, new)]
8+
pub struct SliceNode {
9+
pub input: TensorType,
10+
pub output: TensorType,
11+
pub starts: Vec<usize>,
12+
pub ends: Vec<usize>,
13+
}
14+
15+
impl<PS: PrecisionSettings> NodeCodegen<PS> for SliceNode {
16+
fn output_types(&self) -> Vec<Type> {
17+
vec![Type::Tensor(self.output.clone())]
18+
}
19+
fn input_types(&self) -> Vec<Type> {
20+
vec![Type::Tensor(self.input.clone())]
21+
}
22+
fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
23+
let input = scope.tensor_use_owned(&self.input, node_position);
24+
let output = &self.output.name;
25+
let starts = &self.starts;
26+
let ends = &self.ends;
27+
28+
quote! {
29+
let #output = #input.slice([#(#starts..#ends),*]);
30+
}
31+
}
32+
fn into_node(self) -> Node<PS> {
33+
Node::Slice(self)
34+
}
35+
}
36+
37+
#[cfg(test)]
38+
mod tests {
39+
use burn::record::FullPrecisionSettings;
40+
41+
use super::*;
42+
use crate::burn::{
43+
graph::BurnGraph,
44+
node::{slice::SliceNode, test::assert_tokens},
45+
TensorType,
46+
};
47+
48+
#[test]
49+
fn test_codegen_slice() {
50+
let mut graph = BurnGraph::<FullPrecisionSettings>::default();
51+
graph.register(SliceNode::new(
52+
TensorType::new_float("tensor1", 4),
53+
TensorType::new_float("tensor2", 4),
54+
vec![0, 0, 0, 0],
55+
vec![1, 1, 1, 1],
56+
));
57+
graph.register_input_output(vec!["tensor1".to_string()], vec!["tensor2".to_string()]);
58+
59+
let expected = quote! {
60+
use burn::{
61+
module::Module,
62+
tensor::{backend::Backend, Tensor},
63+
};
64+
65+
#[derive(Module, Debug)]
66+
pub struct Model<B: Backend> {
67+
phantom: core::marker::PhantomData<B>,
68+
device: burn::module::Ignored<B::Device>,
69+
}
70+
71+
impl<B: Backend> Model <B> {
72+
#[allow(unused_variables)]
73+
pub fn new(device: &B::Device) -> Self {
74+
Self {
75+
phantom: core::marker::PhantomData,
76+
device: burn::module::Ignored(device.clone()),
77+
}
78+
}
79+
#[allow(clippy::let_and_return, clippy::approx_constant)]
80+
pub fn forward(&self, tensor1: Tensor<B, 4>) -> Tensor<B, 4> {
81+
let tensor2 = tensor1.slice([0usize..1usize,0usize..1usize,0usize..1usize,0usize..1usize]);
82+
83+
tensor2
84+
}
85+
}
86+
};
87+
88+
assert_tokens(graph.codegen(), expected);
89+
}
90+
}

crates/burn-import/src/onnx/dim_inference.rs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ pub fn dim_inference(node: &mut Node) {
6363
NodeType::Sigmoid => same_as_input(node),
6464
NodeType::Sign => same_as_input(node),
6565
NodeType::Sin => same_as_input(node),
66+
NodeType::Slice => slice_update_outputs(node),
6667
NodeType::Softmax => same_as_input(node),
6768
NodeType::Sqrt => same_as_input(node),
6869
NodeType::Sub => same_as_input(node),
@@ -423,6 +424,33 @@ fn squeeze_update_output(node: &mut Node) {
423424
});
424425
}
425426

427+
fn slice_update_outputs(node: &mut Node) {
428+
let shape = match &node.inputs[1].value {
429+
Some(value) => match value {
430+
Data::Int64s(shape) => Some(shape.clone()),
431+
_ => panic!("Slice: invalid input types"),
432+
},
433+
None => None,
434+
};
435+
436+
if shape.is_none() {
437+
panic!("Slice: invalid shape");
438+
}
439+
440+
let output = match &node.outputs[0].ty {
441+
ArgType::Tensor(tensor) => tensor.clone(),
442+
_ => panic!("Slice: invalid output types"),
443+
};
444+
445+
if let Some(shape) = shape {
446+
node.outputs[0].ty = ArgType::Tensor(TensorType {
447+
dim: shape.len(),
448+
shape: None, // shape is calculated at runtime
449+
..output
450+
});
451+
}
452+
}
453+
426454
/// Update the output tensor dimension based on the "axes" attribute or the second input
427455
fn unsqueeze_update_output(node: &mut Node) {
428456
let axes = if node.inputs.len() == 2 {

crates/burn-import/src/onnx/from_onnx.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use super::ir::{ArgType, Argument, Node, NodeType};
1818

1919
use protobuf::Message;
2020

21-
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
21+
const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 11] = [
2222
NodeType::BatchNormalization,
2323
NodeType::Clip,
2424
NodeType::Conv1d,
@@ -28,6 +28,7 @@ const LIFT_CONSTANTS_FOR_NODE_TYPES: [NodeType; 10] = [
2828
NodeType::Reshape,
2929
NodeType::Unsqueeze,
3030
NodeType::ReduceSum,
31+
NodeType::Slice,
3132
NodeType::Squeeze,
3233
];
3334

0 commit comments

Comments
 (0)