File tree Expand file tree Collapse file tree 1 file changed +36
-0
lines changed Expand file tree Collapse file tree 1 file changed +36
-0
lines changed Original file line number Diff line number Diff line change
1
+ ExpandAs , parent = torch .class (' nn.ExpandAs' , ' nn.Module' )
2
+ -- expands the second input to match the first
3
+
4
+ function ExpandAs :__init ()
5
+ parent .__init (self )
6
+ self .output = {}
7
+ self .gradInput = {}
8
+
9
+ self .sum1 = torch .Tensor ()
10
+ self .sum2 = torch .Tensor ()
11
+ end
12
+
13
+ function ExpandAs :updateOutput (input )
14
+ self .output [1 ] = input [1 ]
15
+ self .output [2 ] = input [2 ]:expandAs (input [1 ])
16
+ return self .output
17
+ end
18
+
19
+ function ExpandAs :updateGradInput (input , gradOutput )
20
+ local b , db = input [2 ], gradOutput [2 ]
21
+ local s1 , s2 = self .sum1 , self .sum2
22
+ local sumSrc , sumDst = db , s1
23
+
24
+ for i = 1 ,b :dim () do
25
+ if b :size (i ) ~= db :size (i ) then
26
+ sumDst :sum (sumSrc , i )
27
+ sumSrc = sumSrc == s1 and s2 or s1
28
+ sumDst = sumDst == s1 and s2 or s1
29
+ end
30
+ end
31
+
32
+ self .gradInput [1 ] = gradOutput [1 ]
33
+ self .gradInput [2 ] = sumSrc
34
+
35
+ return self .gradInput
36
+ end
You can’t perform that action at this time.
0 commit comments