File tree Expand file tree Collapse file tree 1 file changed +36
-0
lines changed Expand file tree Collapse file tree 1 file changed +36
-0
lines changed Original file line number Diff line number Diff line change 1+ ExpandAs , parent = torch .class (' nn.ExpandAs' , ' nn.Module' )
2+ -- expands the second input to match the first
3+
4+ function ExpandAs :__init ()
5+ parent .__init (self )
6+ self .output = {}
7+ self .gradInput = {}
8+
9+ self .sum1 = torch .Tensor ()
10+ self .sum2 = torch .Tensor ()
11+ end
12+
13+ function ExpandAs :updateOutput (input )
14+ self .output [1 ] = input [1 ]
15+ self .output [2 ] = input [2 ]:expandAs (input [1 ])
16+ return self .output
17+ end
18+
19+ function ExpandAs :updateGradInput (input , gradOutput )
20+ local b , db = input [2 ], gradOutput [2 ]
21+ local s1 , s2 = self .sum1 , self .sum2
22+ local sumSrc , sumDst = db , s1
23+
24+ for i = 1 ,b :dim () do
25+ if b :size (i ) ~= db :size (i ) then
26+ sumDst :sum (sumSrc , i )
27+ sumSrc = sumSrc == s1 and s2 or s1
28+ sumDst = sumDst == s1 and s2 or s1
29+ end
30+ end
31+
32+ self .gradInput [1 ] = gradOutput [1 ]
33+ self .gradInput [2 ] = sumSrc
34+
35+ return self .gradInput
36+ end
You can’t perform that action at this time.
0 commit comments