Commit b9c086e

added onnx avgpool1d

1 parent bd06b38

File tree

12 files changed, +320 -11 lines changed


crates/burn-core/src/nn/pool/avg_pool1d.rs

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@ use crate::tensor::Tensor;
 use burn_tensor::module::avg_pool1d;
 
 /// Configuration to create a [1D avg pooling](AvgPool1d) layer.
-#[derive(Config)]
+#[derive(Config, Debug)]
 pub struct AvgPool1dConfig {
     /// The size of the kernel.
     pub kernel_size: usize,
@@ -20,7 +20,7 @@ pub struct AvgPool1dConfig {
     pub padding: PaddingConfig1d,
     /// If the padding is counted in the denominator when computing the average.
     #[config(default = "true")]
-    count_include_pad: bool,
+    pub count_include_pad: bool,
 }
 
 /// Applies a 1D avg pooling over input tensors.
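With `count_include_pad` now public (and the config `Debug`-derivable), downstream code such as the ONNX codegen below can drive it through the config's generated builder methods. A minimal sketch of that flow, assuming the `NdArray` backend purely for illustration:

```rust
use burn::backend::NdArray;
use burn::nn::pool::AvgPool1dConfig;
use burn::nn::PaddingConfig1d;
use burn::tensor::Tensor;

fn demo() {
    // Mirrors pool3 in the export script below: padding is excluded
    // from the averaging denominator.
    let pool = AvgPool1dConfig::new(4)
        .with_stride(2)
        .with_padding(PaddingConfig1d::Explicit(2))
        .with_count_include_pad(false)
        .init();

    let device = Default::default();
    // [batch, channels, length] input; zeros are enough to show shapes.
    let input = Tensor::<NdArray, 3>::zeros([1, 5, 5], &device);
    let output = pool.forward(input);
    assert_eq!(output.dims(), [1, 5, 3]);
}
```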

crates/burn-import/SUPPORTED-ONNX-OPS.md

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ represent the corresponding Burn Op.
 | [Asinh][9]               | ✅ | ✅ |
 | [Atan][10]               | ✅ | ✅ |
 | [Atanh][11]              | ✅ | ✅ |
-| [AveragePool1d][12]      | ❌ | ✅ |
+| [AveragePool1d][12]      | ✅ | ✅ |
 | [AveragePool2d][12]      | ✅ | ✅ |
 | [BatchNormalization][14] | ✅ | ✅ |
 | [Bernoulli][15]          | ❌ | ✅ |

crates/burn-import/onnx-tests/build.rs

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ fn main() {
     ModelGen::new()
         .input("tests/add/add_int.onnx")
         .input("tests/add/add.onnx")
+        .input("tests/avg_pool1d/avg_pool1d.onnx")
         .input("tests/avg_pool2d/avg_pool2d.onnx")
         .input("tests/batch_norm/batch_norm.onnx")
         .input("tests/cast/cast.onnx")
crates/burn-import/onnx-tests/tests/avg_pool1d/avg_pool1d.onnx

Binary file not shown.
crates/burn-import/onnx-tests/tests/avg_pool1d/avg_pool1d.py

Lines changed: 58 additions & 0 deletions
@@ -0,0 +1,58 @@
+#!/usr/bin/env python3
+
+# used to generate model: avg_pool1d.onnx
+
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.pool1 = nn.AvgPool1d(4, stride=2)
+
+        self.pool2 = nn.AvgPool1d(4, stride=2, padding=2, count_include_pad=True)
+
+        self.pool3 = nn.AvgPool1d(4, stride=2, padding=2, count_include_pad=False)
+
+    def forward(self, x1, x2, x3):
+        y1 = self.pool1(x1)
+        y2 = self.pool2(x2)
+        y3 = self.pool3(x3)
+        return y1, y2, y3
+
+
+def main():
+    # Set seed for reproducibility
+    torch.manual_seed(1)
+
+    # Print options
+    torch.set_printoptions(precision=3)
+
+    # Export to onnx
+    model = Model()
+    model.eval()
+    device = torch.device("cpu")
+
+    file_name = "avg_pool1d.onnx"
+    input1 = torch.randn(1, 5, 5, device=device)
+    torch.onnx.export(model, (input1, input1, input1), file_name,
+                      verbose=False, opset_version=16)
+
+    print("Finished exporting model to {}".format(file_name))
+
+    # Output some test data for use in the test
+    print("Test input data shape: {}".format(input1.shape))
+    print("Test input data: {}".format(input1))
+    output1, output2, output3 = model.forward(input1, input1, input1)
+    print("Test output1 data shape: {}".format(output1.shape))
+    print("Test output2 data shape: {}".format(output2.shape))
+    print("Test output3 data shape: {}".format(output3.shape))
+    print("Test output1: {}".format(output1))
+    print("Test output2: {}".format(output2))
+    print("Test output3: {}".format(output3))
+
+
+if __name__ == '__main__':
+    main()
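For reference, the three pools exercise the standard 1D pooling output-length relation (as documented for `torch.nn.AvgPool1d`):

$$L_{\text{out}} = \left\lfloor \frac{L_{\text{in}} + 2 \cdot \text{padding} - \text{kernel\_size}}{\text{stride}} \right\rfloor + 1$$

With $L_{\text{in}} = 5$: `pool1` (kernel 4, stride 2, no padding) yields $L_{\text{out}} = 1$, while `pool2` and `pool3` (padding 2) yield $L_{\text{out}} = 3$, matching the `[1, 5, 1]` and `[1, 5, 3]` shapes asserted in the Rust test below.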

crates/burn-import/onnx-tests/tests/onnx_tests.rs

Lines changed: 48 additions & 0 deletions
@@ -18,6 +18,7 @@ include_models!(
     add_int,
     add,
     avg_pool2d,
+    avg_pool1d,
     batch_norm,
     cast,
     clip_opset16,
@@ -498,6 +499,53 @@
         assert_eq!(output.to_data(), expected);
     }
 
+    #[test]
+    fn avg_pool1d() {
+        // Initialize the model without weights (because the exported file does not contain them)
+        let device = Default::default();
+        let model: avg_pool1d::Model<Backend> = avg_pool1d::Model::new(&device);
+
+        // Run the model
+        let input = Tensor::<Backend, 3>::from_floats(
+            [[
+                [-1.526, -0.750, -0.654, -1.609, -0.100],
+                [-0.609, -0.980, -1.609, -0.712, 1.171],
+                [1.767, -0.095, 0.139, -1.579, -0.321],
+                [-0.299, 1.879, 0.336, 0.275, 1.716],
+                [-0.056, 0.911, -1.392, 2.689, -0.111],
+            ]],
+            &device,
+        );
+        let (output1, output2, output3) = model.forward(input.clone(), input.clone(), input);
+        let expected1 = Data::from([[[-1.135], [-0.978], [0.058], [0.548], [0.538]]]);
+        let expected2 = Data::from([[
+            [-0.569, -1.135, -0.591],
+            [-0.397, -0.978, -0.288],
+            [0.418, 0.058, -0.440],
+            [0.395, 0.548, 0.582],
+            [0.214, 0.538, 0.296],
+        ]]);
+        let expected3 = Data::from([[
+            [-1.138, -1.135, -0.788],
+            [-0.794, -0.978, -0.383],
+            [0.836, 0.058, -0.587],
+            [0.790, 0.548, 0.776],
+            [0.427, 0.538, 0.395],
+        ]]);
+
+        let expected_shape1 = Shape::from([1, 5, 1]);
+        let expected_shape2 = Shape::from([1, 5, 3]);
+        let expected_shape3 = Shape::from([1, 5, 3]);
+
+        assert_eq!(output1.shape(), expected_shape1);
+        assert_eq!(output2.shape(), expected_shape2);
+        assert_eq!(output3.shape(), expected_shape3);
+
+        output1.to_data().assert_approx_eq(&expected1, 3);
+        output2.to_data().assert_approx_eq(&expected2, 3);
+        output3.to_data().assert_approx_eq(&expected3, 3);
+    }
+
     #[test]
     fn avg_pool2d() {
         // Initialize the model without weights (because the exported file does not contain them)
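The only difference between `expected2` and `expected3` is the divisor used for windows that overlap the padding. A sketch of that comparison through the functional op (`burn_tensor::module::avg_pool1d`, which `AvgPool1d::forward` delegates to; the positional signature shown here is an assumption):

```rust
use burn::backend::NdArray;
use burn_tensor::module::avg_pool1d;
use burn_tensor::Tensor;

fn demo() {
    let device = Default::default();
    // First channel of the test input above.
    let x = Tensor::<NdArray, 3>::from_floats(
        [[[-1.526, -0.750, -0.654, -1.609, -0.100]]],
        &device,
    );

    // kernel_size = 4, stride = 2, padding = 2, as in pool2/pool3.
    let include = avg_pool1d(x.clone(), 4, 2, 2, true);
    let exclude = avg_pool1d(x, 4, 2, 2, false);

    // The first window sees [0, 0, -1.526, -0.750] after padding:
    //   count_include_pad = true:  -2.276 / 4 ≈ -0.569 (expected2's first entry)
    //   count_include_pad = false: -2.276 / 2 = -1.138 (expected3's first entry)
    let _ = (include, exclude);
}
```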
crates/burn-import/src/burn/node/avg_pool1d.rs

Lines changed: 155 additions & 0 deletions
@@ -0,0 +1,155 @@
+use proc_macro2::TokenStream;
+use quote::quote;
+
+use burn::{nn::pool::AvgPool1dConfig, record::PrecisionSettings};
+
+use super::{Node, NodeCodegen};
+use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
+
+#[derive(Debug, Clone)]
+pub struct AvgPool1dNode {
+    pub field: OtherType,
+    pub input: TensorType,
+    pub output: TensorType,
+    pub config: AvgPool1dConfig,
+}
+
+impl AvgPool1dNode {
+    pub fn new<S: AsRef<str>>(
+        name: S,
+        input: TensorType,
+        output: TensorType,
+        config: AvgPool1dConfig,
+    ) -> Self {
+        Self {
+            field: OtherType::new(
+                name,
+                quote! {
+                    AvgPool1d
+                },
+            ),
+            input,
+            output,
+            config,
+        }
+    }
+}
+
+impl<PS: PrecisionSettings> NodeCodegen<PS> for AvgPool1dNode {
+    fn input_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.input.clone())]
+    }
+    fn output_types(&self) -> Vec<Type> {
+        vec![Type::Tensor(self.output.clone())]
+    }
+    fn field_type(&self) -> Option<Type> {
+        Some(Type::Other(self.field.clone()))
+    }
+
+    fn field_init(&self) -> Option<TokenStream> {
+        let name = &self.field.name;
+        let kernel_size = self.config.kernel_size.to_tokens();
+        let strides = self.config.stride.to_tokens();
+        let padding = self.config.padding.to_tokens();
+        let count_include_pad = self.config.count_include_pad;
+
+        let tokens = quote! {
+            let #name = AvgPool1dConfig::new(#kernel_size)
+                .with_stride(#strides)
+                .with_padding(#padding)
+                .with_count_include_pad(#count_include_pad)
+                .init();
+        };
+
+        Some(tokens)
+    }
+
+    fn forward(&self, scope: &mut Scope, node_position: usize) -> TokenStream {
+        let input = scope.tensor_use_owned(&self.input, node_position);
+        let output = &self.output.name;
+        let field = &self.field.name;
+
+        quote! {
+            let #output = self.#field.forward(#input);
+        }
+    }
+
+    fn register_imports(&self, imports: &mut BurnImports) {
+        imports.register("burn::nn::PaddingConfig1d");
+        imports.register("burn::nn::pool::AvgPool1d");
+        imports.register("burn::nn::pool::AvgPool1dConfig");
+    }
+
+    fn into_node(self) -> Node<PS> {
+        Node::AvgPool1d(self)
+    }
+
+    fn field_serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
+        S::serialize_none(serializer)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
+    use burn::{nn::PaddingConfig1d, record::FullPrecisionSettings};
+
+    #[test]
+    fn test_codegen() {
+        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
+
+        graph.register(AvgPool1dNode::new(
+            "avg_pool1d",
+            TensorType::new_float("input", 3),
+            TensorType::new_float("output", 3),
+            AvgPool1dConfig::new(3)
+                .with_stride(1)
+                .with_padding(PaddingConfig1d::Valid),
+        ));
+
+        graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
+
+        let expected = quote! {
+            use burn::{
+                module::Module,
+                tensor::{backend::Backend, Tensor},
+            };
+            use burn::nn::PaddingConfig1d;
+            use burn::nn::pool::AvgPool1d;
+            use burn::nn::pool::AvgPool1dConfig;
+
+            #[derive(Module, Debug)]
+            pub struct Model<B: Backend> {
+                avg_pool1d: AvgPool1d,
+                phantom: core::marker::PhantomData<B>,
+                device: burn::module::Ignored<B::Device>,
+            }
+
+            impl<B: Backend> Model<B> {
+                #[allow(unused_variables)]
+                pub fn new(device: &B::Device) -> Self {
+                    let avg_pool1d = AvgPool1dConfig::new(3)
+                        .with_stride(1)
+                        .with_padding(PaddingConfig1d::Valid)
+                        .with_count_include_pad(true)
+                        .init();
+
+                    Self {
+                        avg_pool1d,
+                        phantom: core::marker::PhantomData,
+                        device: burn::module::Ignored(device.clone()),
+                    }
+                }
+                #[allow(clippy::let_and_return, clippy::approx_constant)]
+                pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
+                    let output = self.avg_pool1d.forward(input);
+
+                    output
+                }
+            }
+        };
+
+        assert_tokens(graph.codegen(), expected);
+    }
+}

crates/burn-import/src/burn/node/base.rs

Lines changed: 10 additions & 7 deletions
@@ -1,11 +1,11 @@
 use super::{
-    avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode, binary::BinaryNode, clip::ClipNode,
-    concat::ConcatNode, constant::ConstantNode, conv1d::Conv1dNode, conv2d::Conv2dNode,
-    conv_transpose_2d::ConvTranspose2dNode, dropout::DropoutNode, gather::GatherNode,
-    global_avg_pool::GlobalAvgPoolNode, layer_norm::LayerNormNode, linear::LinearNode,
-    mask_where::WhereNode, matmul::MatmulNode, max_pool1d::MaxPool1dNode,
-    max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode, unary::UnaryNode,
-    unsqueeze::UnsqueezeNode,
+    avg_pool1d::AvgPool1dNode, avg_pool2d::AvgPool2dNode, batch_norm::BatchNormNode,
+    binary::BinaryNode, clip::ClipNode, concat::ConcatNode, constant::ConstantNode,
+    conv1d::Conv1dNode, conv2d::Conv2dNode, conv_transpose_2d::ConvTranspose2dNode,
+    dropout::DropoutNode, gather::GatherNode, global_avg_pool::GlobalAvgPoolNode,
+    layer_norm::LayerNormNode, linear::LinearNode, mask_where::WhereNode, matmul::MatmulNode,
+    max_pool1d::MaxPool1dNode, max_pool2d::MaxPool2dNode, prelu::PReluNode, reshape::ReshapeNode,
+    unary::UnaryNode, unsqueeze::UnsqueezeNode,
 };
 use crate::burn::{BurnImports, Scope, Type};
 use burn::backend::NdArray;
@@ -75,6 +75,7 @@ pub trait NodeCodegen<PS: PrecisionSettings>: std::fmt::Debug {
 
 #[derive(Debug, Clone)]
 pub enum Node<PS: PrecisionSettings> {
+    AvgPool1d(AvgPool1dNode),
     AvgPool2d(AvgPool2dNode),
     BatchNorm(BatchNormNode<PS>),
     Binary(BinaryNode),
@@ -103,6 +104,7 @@ macro_rules! match_all {
     ($self:expr, $func:expr) => {{
         #[allow(clippy::redundant_closure_call)]
         match $self {
+            Node::AvgPool1d(node) => $func(node),
             Node::AvgPool2d(node) => $func(node),
             Node::BatchNorm(node) => $func(node),
             Node::Binary(node) => $func(node),
@@ -141,6 +143,7 @@ impl<PS: PrecisionSettings> Serialize for Node<PS> {
 impl<PS: PrecisionSettings> Node<PS> {
     pub fn name(&self) -> &str {
         match self {
+            Node::AvgPool1d(_) => "avg_pool1d",
             Node::AvgPool2d(_) => "avg_pool2d",
             Node::BatchNorm(_) => "batch_norm",
             Node::Binary(binary) => binary.binary_type.as_str(),

crates/burn-import/src/burn/node/mod.rs

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 mod base;
 
+pub(crate) mod avg_pool1d;
 pub(crate) mod avg_pool2d;
 pub(crate) mod batch_norm;
 pub(crate) mod binary;

crates/burn-import/src/onnx/dim_inference.rs

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@ use super::{
 pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
     match node.node_type {
         NodeType::Add => same_as_input(node),
+        NodeType::AveragePool1d => same_as_input(node),
         NodeType::AveragePool2d => same_as_input(node),
         NodeType::BatchNormalization => same_as_input(node),
         NodeType::Cast => cast_update_outputs(node),
@@ -38,6 +39,7 @@ pub fn dim_inference(node: &mut Node, graph_io: &mut OnnxGraphIO) {
         NodeType::Log => same_as_input(node),
         NodeType::LogSoftmax => same_as_input(node),
         NodeType::MatMul => matmul_update_outputs(node),
+        NodeType::MaxPool1d => same_as_input(node),
         NodeType::MaxPool2d => same_as_input(node),
         NodeType::Mul => same_as_input(node),
         NodeType::Neg => same_as_input(node),
