Skip to content

Commit 6dc886f

Browse files
committed
fix tests
1 parent 7e05956 commit 6dc886f

File tree

1 file changed

+7
-16
lines changed
  • crates/burn-import/src/burn/node

1 file changed

+7
-16
lines changed

crates/burn-import/src/burn/node/prelu.rs

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use super::{Node, NodeCodegen, SerializationBackend};
2-
use crate::burn::{BurnImports, OtherType, Scope, TensorType, ToTokens, Type};
2+
use crate::burn::{BurnImports, OtherType, Scope, TensorType, Type};
33
use burn::{
44
module::{Param, ParamId},
55
nn::{PReluConfig, PReluRecord},
@@ -55,11 +55,8 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
5555

5656
fn field_init(&self) -> Option<TokenStream> {
5757
let name = &self.field.name;
58-
59-
let num_parameters = self.config.num_parameters.to_tokens();
60-
let alpha = self.config.alpha.to_tokens();
6158
let tokens = quote! {
62-
let #name = PReluConfig::new(#num_parameters, #alpha)
59+
let #name = PReluConfig::new()
6360
.init(device);
6461
};
6562

@@ -101,14 +98,8 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
10198
#[cfg(test)]
10299
mod tests {
103100
use super::*;
104-
use crate::burn::{
105-
graph::BurnGraph,
106-
node::{conv1d::Conv1dNode, test::assert_tokens},
107-
TensorType,
108-
};
109-
use burn::{
110-
nn::conv::Conv1dConfig, nn::PaddingConfig1d, record::FullPrecisionSettings, tensor::Data,
111-
};
101+
use crate::burn::{graph::BurnGraph, node::test::assert_tokens, TensorType};
102+
use burn::{record::FullPrecisionSettings, tensor::Data};
112103

113104
#[test]
114105
fn test_codegen() {
@@ -125,8 +116,8 @@ mod tests {
125116
graph.register_input_output(vec!["input".to_string()], vec!["output".to_string()]);
126117

127118
let expected = quote! {
128-
use burn::nn::prelu::PRelu;
129-
use burn::nn::prelu::PReluConfig;
119+
use burn::nn::PRelu;
120+
use burn::nn::PReluConfig;
130121
use burn::{
131122
module::Module,
132123
tensor::{backend::Backend, Tensor},
@@ -140,7 +131,7 @@ mod tests {
140131
impl<B: Backend> Model<B> {
141132
#[allow(unused_variables)]
142133
pub fn new(device: &B::Device) -> Self {
143-
let prelu = PReluConfig::new(1, 0.25).init(device);
134+
let prelu = PReluConfig::new().init(device);
144135
Self {
145136
prelu,
146137
phantom: core::marker::PhantomData,

0 commit comments

Comments
 (0)