@@ -4,25 +4,38 @@ function CMaxTable:__init()
    parent.__init(self)
    self.gradInput = {}
    self.maxIdx = torch.Tensor()
+   self.mask = torch.Tensor()
+   self.maxVals = torch.Tensor()
+   self.gradMaxVals = torch.Tensor()
 end
 
 function CMaxTable:updateOutput(input)
    self.output:resizeAs(input[1]):copy(input[1])
    self.maxIdx:resizeAs(input[1]):fill(1)
    for i=2,#input do
-      local mask = torch.gt(input[i], self.output)
-      self.maxIdx:maskedFill(mask, i)
-      self.output:maskedCopy(mask, input[i][mask])
+      self.maskByteTensor = self.maskByteTensor or
+         (torch.type(self.output) == 'torch.CudaTensor' and
+          torch.CudaByteTensor() or torch.ByteTensor())
+      self.mask:gt(input[i], self.output)
+      self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
+      self.maxIdx:maskedFill(self.maskByteTensor, i)
+      self.maxVals:maskedSelect(input[i], self.maskByteTensor)
+      self.output:maskedCopy(self.maskByteTensor, self.maxVals)
    end
    return self.output
 end
 
 function CMaxTable:updateGradInput(input, gradOutput)
    for i=1,#input do
-      self.gradInput[i] = input[i].new()
+      self.gradInput[i] = self.gradInput[i] or input[i].new()
       self.gradInput[i]:resizeAs(input[i]):fill(0.0)
-      local mask = torch.eq(self.maxIdx, i)
-      self.gradInput[i]:maskedCopy(mask, gradOutput[mask])
+      self.maskByteTensor = self.maskByteTensor or
+         (torch.type(self.output) == 'torch.CudaTensor' and
+          torch.CudaByteTensor() or torch.ByteTensor())
+      self.mask:eq(self.maxIdx, i)
+      self.maskByteTensor:resize(self.mask:size()):copy(self.mask)
+      self.gradMaxVals:maskedSelect(gradOutput, self.maskByteTensor)
+      self.gradInput[i]:maskedCopy(self.maskByteTensor, self.gradMaxVals)
    end
 
    for i=#input+1, #self.gradInput do
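The change preallocates `self.mask`, `self.maxVals`, and `self.gradMaxVals` in `__init` and reuses them, together with a lazily created `ByteTensor`/`CudaByteTensor` copy of the mask, in `updateOutput` and `updateGradInput`, instead of allocating fresh mask tensors via `torch.gt`/`torch.eq` and boolean-indexed copies on every call. Below is a minimal usage sketch of the module's forward/backward behaviour; it assumes the `nn` package is installed, and the values are purely illustrative:

```lua
require 'nn'

local m = nn.CMaxTable()
local a = torch.Tensor{1, 5, 3}
local b = torch.Tensor{4, 2, 6}

-- Forward: element-wise max over the input table -> {4, 5, 6}
local out = m:forward({a, b})

-- Backward: each gradient element is routed to the input that held the max
local gi = m:backward({a, b}, torch.Tensor{1, 1, 1})
-- gi[1] = {0, 1, 0}, gi[2] = {1, 0, 1}
```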