Skip to content

Commit c33cd5b

Browse files
committed
Add ExpandAs module
1 parent 88cbe99 commit c33cd5b

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

ExpandAs.lua

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
ExpandAs, parent = torch.class('nn.ExpandAs', 'nn.Module')
2+
-- expands the second input to match the first
3+
4+
function ExpandAs:__init()
5+
parent.__init(self)
6+
self.output = {}
7+
self.gradInput = {}
8+
9+
self.sum1 = torch.Tensor()
10+
self.sum2 = torch.Tensor()
11+
end
12+
13+
function ExpandAs:updateOutput(input)
14+
self.output[1] = input[1]
15+
self.output[2] = input[2]:expandAs(input[1])
16+
return self.output
17+
end
18+
19+
function ExpandAs:updateGradInput(input, gradOutput)
20+
local b, db = input[2], gradOutput[2]
21+
local s1, s2 = self.sum1, self.sum2
22+
local sumSrc, sumDst = db, s1
23+
24+
for i=1,b:dim() do
25+
if b:size(i) ~= db:size(i) then
26+
sumDst:sum(sumSrc, i)
27+
sumSrc = sumSrc == s1 and s2 or s1
28+
sumDst = sumDst == s1 and s2 or s1
29+
end
30+
end
31+
32+
self.gradInput[1] = gradOutput[1]
33+
self.gradInput[2] = sumSrc
34+
35+
return self.gradInput
36+
end

0 commit comments

Comments
 (0)