1
1
use super :: { Node , NodeCodegen , SerializationBackend } ;
2
- use crate :: burn:: { BurnImports , OtherType , Scope , TensorType , ToTokens , Type } ;
2
+ use crate :: burn:: { BurnImports , OtherType , Scope , TensorType , Type } ;
3
3
use burn:: {
4
4
module:: { Param , ParamId } ,
5
5
nn:: { PReluConfig , PReluRecord } ,
@@ -55,11 +55,8 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
55
55
56
56
fn field_init ( & self ) -> Option < TokenStream > {
57
57
let name = & self . field . name ;
58
-
59
- let num_parameters = self . config . num_parameters . to_tokens ( ) ;
60
- let alpha = self . config . alpha . to_tokens ( ) ;
61
58
let tokens = quote ! {
62
- let #name = PReluConfig :: new( #num_parameters , #alpha )
59
+ let #name = PReluConfig :: new( )
63
60
. init( device) ;
64
61
} ;
65
62
@@ -101,14 +98,8 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for PReluNode<PS> {
101
98
#[ cfg( test) ]
102
99
mod tests {
103
100
use super :: * ;
104
- use crate :: burn:: {
105
- graph:: BurnGraph ,
106
- node:: { conv1d:: Conv1dNode , test:: assert_tokens} ,
107
- TensorType ,
108
- } ;
109
- use burn:: {
110
- nn:: conv:: Conv1dConfig , nn:: PaddingConfig1d , record:: FullPrecisionSettings , tensor:: Data ,
111
- } ;
101
+ use crate :: burn:: { graph:: BurnGraph , node:: test:: assert_tokens, TensorType } ;
102
+ use burn:: { record:: FullPrecisionSettings , tensor:: Data } ;
112
103
113
104
#[ test]
114
105
fn test_codegen ( ) {
@@ -125,8 +116,8 @@ mod tests {
125
116
graph. register_input_output ( vec ! [ "input" . to_string( ) ] , vec ! [ "output" . to_string( ) ] ) ;
126
117
127
118
let expected = quote ! {
128
- use burn:: nn:: prelu :: PRelu ;
129
- use burn:: nn:: prelu :: PReluConfig ;
119
+ use burn:: nn:: PRelu ;
120
+ use burn:: nn:: PReluConfig ;
130
121
use burn:: {
131
122
module:: Module ,
132
123
tensor:: { backend:: Backend , Tensor } ,
@@ -140,7 +131,7 @@ mod tests {
140
131
impl <B : Backend > Model <B > {
141
132
#[ allow( unused_variables) ]
142
133
pub fn new( device: & B :: Device ) -> Self {
143
- let prelu = PReluConfig :: new( 1 , 0.25 ) . init( device) ;
134
+ let prelu = PReluConfig :: new( ) . init( device) ;
144
135
Self {
145
136
prelu,
146
137
phantom: core:: marker:: PhantomData ,
0 commit comments