Skip to content

Commit e360bdf

Browse files
committed
bug fix
1 parent 8d226c7 commit e360bdf

File tree

3 files changed

+2
-18
lines changed

3 files changed

+2
-18
lines changed

crates/burn-import/src/burn/node/prelu.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
8989
}
9090
}
9191
fn register_imports(&self, imports: &mut BurnImports) {
92-
imports.register("burn::nn::PRelu");
9392
imports.register("burn::nn::prelu::PRelu");
9493
imports.register("burn::nn::prelu::PReluConfig");
9594
}

crates/burn-import/src/onnx/op_configuration.rs

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -120,22 +120,6 @@ pub fn max_pool2d_config(curr: &Node) -> MaxPool2dConfig {
120120
.with_padding(padding)
121121
.with_dilation([dilations[0] as usize, dilations[1] as usize])
122122
}
123-
pub fn prelu_config(curr: &Node) -> PReluConfig {
124-
let mut alpha = 0.01;
125-
let mut num_parameters = 0;
126-
for (key, value) in curr.attrs.iter() {
127-
match key.as_str() {
128-
"alpha" => alpha = value.clone().into_f32(),
129-
"num_parameters" => num_parameters = value.clone().into_i32(),
130-
_ => {}
131-
}
132-
}
133-
134-
PReluConfig::new()
135-
.with_num_parameters(num_parameters as usize)
136-
.with_alpha(alpha as f64)
137-
}
138-
139123
pub fn conv_transpose2d_config(curr: &Node) -> ConvTranspose2dConfig {
140124
let mut attrs = curr.attrs.clone();
141125
let kernel_shape = attrs

crates/burn-import/src/onnx/to_burn.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::{
55
};
66

77
use burn::{
8+
nn::PReluConfig,
89
record::{FullPrecisionSettings, HalfPrecisionSettings, PrecisionSettings},
910
tensor::{DataSerialize, Element},
1011
};
@@ -701,7 +702,7 @@ impl OnnxGraph {
701702
let input = node.inputs.first().unwrap().to_tensor_type();
702703
let output = node.outputs.first().unwrap().to_tensor_type();
703704
let weight = extract_data_serialize::<PS::FloatElem>(1, &node).unwrap();
704-
let config = prelu_config(&node);
705+
let config = PReluConfig::new();
705706
let name = &node.name;
706707
PReluNode::<PS>::new(name, input, output, weight, config)
707708
}

0 commit comments

Comments
 (0)