forked from Element-Research/rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path RepeaterCriterion.lua
37 lines (33 loc) · 1.27 KB
/
RepeaterCriterion.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
------------------------------------------------------------------------
--[[ RepeaterCriterion ]]--
-- Wraps one criterion and evaluates it against every element of an
-- input table, reusing a single target for all of them (the target is
-- "repeated"). Meant to sit on top of nn.Repeater or nn.Sequencer.
------------------------------------------------------------------------

-- Refuse to overwrite an existing nn.RepeaterCriterion; per the error
-- message, a pre-existing definition points at an outdated nnx package.
local alreadyRegistered = nn.RepeaterCriterion
assert(not alreadyRegistered, "update nnx package : luarocks install nnx")

local RepeaterCriterion, parent = torch.class('nn.RepeaterCriterion', 'nn.Criterion')
--- Construct the repeater around `criterion`.
-- @param criterion the criterion evaluated once per element of the input table
function RepeaterCriterion:__init(criterion)
   parent.__init(self)
   -- gradInput becomes a plain table (one entry per input element),
   -- replacing whatever parent.__init stored there.
   self.criterion, self.gradInput = criterion, {}
end
--- Summed loss of the wrapped criterion over every element of `inputTable`,
-- each element compared against the same `target`.
-- The work lives in updateOutput (the nn.Criterion hook that forward
-- delegates to), so callers that invoke updateOutput directly get the
-- real loss instead of a stale self.output.
-- @param inputTable sequence (Lua table) of inputs for the wrapped criterion
-- @param target target shared by every element
-- @return the summed loss, also stored in self.output
function RepeaterCriterion:updateOutput(inputTable, target)
   local criterion = self.criterion
   local total = 0
   for _, input in ipairs(inputTable) do
      total = total + criterion:forward(input, target)
   end
   self.output = total
   return total
end

--- Kept as an explicit entry point; simply delegates to updateOutput.
function RepeaterCriterion:forward(inputTable, target)
   return self:updateOutput(inputTable, target)
end
--- Gradient of the summed loss with respect to each element of `inputTable`.
-- The work lives in updateGradInput (the nn.Criterion hook that backward
-- delegates to), so callers that invoke updateGradInput directly are served.
-- @param inputTable sequence (Lua table) of inputs for the wrapped criterion
-- @param target target shared by every element
-- @return self.gradInput, a table with one gradient per input element
function RepeaterCriterion:updateGradInput(inputTable, target)
   local gradInput = self.gradInput
   for i, input in ipairs(inputTable) do
      -- copy the wrapped criterion's result into our own buffer; presumably
      -- that criterion reuses its gradInput storage across calls, so the
      -- next iteration would otherwise overwrite this entry
      gradInput[i] = nn.rnn.recursiveCopy(gradInput[i], self.criterion:backward(input, target))
   end
   -- drop entries left over from a previous, longer input table;
   -- walking downward never leaves holes in the sequence
   for i = #gradInput, #inputTable + 1, -1 do
      gradInput[i] = nil
   end
   return gradInput
end

--- Kept as an explicit entry point; simply delegates to updateGradInput.
function RepeaterCriterion:backward(inputTable, target)
   return self:updateGradInput(inputTable, target)
end
--- Convert the wrapped criterion and the cached gradients to tensor type `type`.
-- @param type tensor type name to convert to (e.g. 'torch.FloatTensor')
-- @param ... extra arguments forwarded to the wrapped criterion's type()
-- @return self, so chained conversions such as crit:cuda() keep the repeater
--   rather than handing back the inner criterion
function RepeaterCriterion:type(type, ...)
   -- the target type must be passed through: Tensor:type() called without
   -- an argument returns the type name instead of converting the tensor
   self.gradInput = nn.rnn.recursiveType(self.gradInput, type)
   self.criterion:type(type, ...)
   return self
end