diff --git a/BiSequencerLM.lua b/BiSequencerLM.lua
index deef19d..c32ebf5 100644
--- a/BiSequencerLM.lua
+++ b/BiSequencerLM.lua
@@ -1,25 +1,25 @@
 ------------------------------------------------------------------------
 --[[ BiSequencerLM ]]--
---- Encapsulates forward, backward and merge modules. 
+-- Encapsulates forward, backward and merge modules.
 -- Input is a sequence (a table) of tensors.
 -- Output is a sequence (a table) of tensors of the same length.
--- Applies a `fwd` rnn instance to the first `N-1` elements in the 
+-- Applies a `fwd` rnn instance to the first `N-1` elements in the
 -- sequence in forward order.
--- Applies the `bwd` rnn in reverse order to the last `N-1` elements 
+-- Applies the `bwd` rnn in reverse order to the last `N-1` elements
 -- (from second-to-last element to first element).
--- Note : you shouldn't stack these for language modeling. 
+-- Note : you shouldn't stack these for language modeling.
 -- Instead, stack each fwd/bwd seqs and encapsulate these.
 ------------------------------------------------------------------------
 local _ = require 'moses'
 local BiSequencerLM, parent = torch.class('nn.BiSequencerLM', 'nn.AbstractSequencer')
 
 function BiSequencerLM:__init(forward, backward, merge)
-   
+
    if not torch.isTypeOf(forward, 'nn.Module') then
       error"BiSequencerLM: expecting nn.Module instance at arg 1"
    end
    self.forwardModule = forward
-   
+
    self.backwardModule = backward
    if not self.backwardModule then
       self.backwardModule = forward:clone()
@@ -28,7 +28,7 @@ function BiSequencerLM:__init(forward, backward, merge)
    if not torch.isTypeOf(self.backwardModule, 'nn.Module') then
       error"BiSequencerLM: expecting nn.Module instance at arg 2"
    end
-   
+
    if torch.type(merge) == 'number' then
       self.mergeModule = nn.JoinTable(1, merge)
    elseif merge == nil then
@@ -38,36 +38,36 @@ function BiSequencerLM:__init(forward, backward, merge)
    else
       error"BiSequencerLM: expecting nn.Module or number instance at arg 3"
    end
-   
+
    if torch.isTypeOf(self.forwardModule, 'nn.AbstractRecurrent') then
       self.fwdSeq = nn.Sequencer(self.forwardModule)
    else -- assumes a nn.Sequencer or stack thereof
       self.fwdSeq = self.forwardModule
    end
-   
+
    if torch.isTypeOf(self.backwardModule, 'nn.AbstractRecurrent') then
       self.bwdSeq = nn.Sequencer(self.backwardModule)
    else
       self.bwdSeq = self.backwardModule
    end
    self.mergeSeq = nn.Sequencer(self.mergeModule)
-   
+
    self._fwd = self.fwdSeq
-   
+
    self._bwd = nn.Sequential()
    self._bwd:add(nn.ReverseTable())
    self._bwd:add(self.bwdSeq)
    self._bwd:add(nn.ReverseTable())
-   
+
    self._merge = nn.Sequential()
    self._merge:add(nn.ZipTable())
    self._merge:add(self.mergeSeq)
-   
-   
+
+
    parent.__init(self)
-   
+
    self.modules = {self._fwd, self._bwd, self._merge}
-   
+
    self.output = {}
    self.gradInput = {}
 end
@@ -76,38 +76,38 @@ function BiSequencerLM:updateOutput(input)
    assert(torch.type(input) == 'table', 'Expecting table at arg 1')
    local nStep = #input
    assert(nStep > 1, "Expecting at least 2 elements in table")
-   
+
    -- forward through fwd and bwd rnn in fwd and reverse order
    self._fwdOutput = self._fwd:updateOutput(_.first(input, nStep - 1))
    self._bwdOutput = self._bwd:updateOutput(_.last(input, nStep - 1))
-   
+
    -- empty outputs
    for k,v in ipairs(self.output) do self.output[k] = nil end
-   
+
    -- padding for first and last elements of fwd and bwd outputs, respectively
-   self._firstStep = nn.rnn.recursiveResizeAs(self._firstStep, self._fwdOutput[1])
-   nn.rnn.recursiveFill(self._firstStep, 0)
-   self._lastStep = nn.rnn.recursiveResizeAs(self._lastStep, self._bwdOutput[1])
-   nn.rnn.recursiveFill(self._lastStep, 0)
-   
+   self._firstStep = nn.utils.recursiveResizeAs(self._firstStep, self._fwdOutput[1])
+   nn.utils.recursiveFill(self._firstStep, 0)
+   self._lastStep = nn.utils.recursiveResizeAs(self._lastStep, self._bwdOutput[1])
+   nn.utils.recursiveFill(self._lastStep, 0)
+
    -- { { zeros, fwd1, fwd2, ..., fwdN}, {bwd1, bwd2, ..., bwdN, zeros} }
    self._mergeInput = {_.clone(self._fwdOutput), _.clone(self._bwdOutput)}
    table.insert(self._mergeInput[1], 1, self._firstStep)
    table.insert(self._mergeInput[2], self._lastStep)
    assert(#self._mergeInput[1] == #self._mergeInput[2])
-   
+
    self.output = self._merge:updateOutput(self._mergeInput)
-   
+
    return self.output
 end
 
 function BiSequencerLM:updateGradInput(input, gradOutput)
    local nStep = #input
-   
+
    self._mergeGradInput = self._merge:updateGradInput(self._mergeInput, gradOutput)
    self._fwdGradInput = self._fwd:updateGradInput(_.first(input, nStep - 1), _.last(self._mergeGradInput[1], nStep - 1))
    self._bwdGradInput = self._bwd:updateGradInput(_.last(input, nStep - 1), _.first(self._mergeGradInput[2], nStep - 1))
-   
+
    -- add fwd rnn gradInputs to bwd rnn gradInputs
    for i=1,nStep do
       if i == 1 then
@@ -115,17 +115,17 @@ function BiSequencerLM:updateGradInput(input, gradOutput)
       elseif i == nStep then
          self.gradInput[nStep] = self._bwdGradInput[nStep-1]
       else
-         self.gradInput[i] = nn.rnn.recursiveCopy(self.gradInput[i], self._fwdGradInput[i])
-         nn.rnn.recursiveAdd(self.gradInput[i], self._bwdGradInput[i-1])
+         self.gradInput[i] = nn.utils.recursiveCopy(self.gradInput[i], self._fwdGradInput[i])
+         nn.utils.recursiveAdd(self.gradInput[i], self._bwdGradInput[i-1])
       end
    end
-   
+
    return self.gradInput
 end
 
 function BiSequencerLM:accGradParameters(input, gradOutput, scale)
    local nStep = #input
-   
+
    self._merge:accGradParameters(self._mergeInput, gradOutput, scale)
    self._fwd:accGradParameters(_.first(input, nStep - 1), _.last(self._mergeGradInput[1], nStep - 1), scale)
    self._bwd:accGradParameters(_.last(input, nStep - 1), _.first(self._mergeGradInput[2], nStep - 1), scale)
@@ -133,7 +133,7 @@ end
 
 function BiSequencerLM:accUpdateGradParameters(input, gradOutput, lr)
    local nStep = #input
-   
+
    self._merge:accUpdateGradParameters(self._mergeInput, gradOutput, lr)
    self._fwd:accUpdateGradParameters(_.first(input, nStep - 1), _.last(self._mergeGradInput[1], nStep - 1), lr)
    self._bwd:accUpdateGradParameters(_.last(input, nStep - 1), _.first(self._mergeGradInput[2], nStep - 1), lr)
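For reference, a minimal usage sketch of the nn.BiSequencerLM API exercised above; it is not part of the patch, and the choice of nn.LSTM plus the sizes are illustrative assumptions (merge is left at its default):

require 'rnn'
local fwd = nn.LSTM(10, 10)                  -- forward rnn applied to steps 1..N-1
local brnn = nn.BiSequencerLM(fwd)           -- bwd defaults to fwd:clone(), per the constructor above
local input = {torch.randn(8, 10), torch.randn(8, 10), torch.randn(8, 10)}  -- 3 steps, batch of 8
local output = brnn:forward(input)           -- table of 3 tensors; fwd and bwd features merged per step
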
diff --git a/FireModule.lua b/FireModule.lua
index c927c23..f4e583e 100644
--- a/FireModule.lua
+++ b/FireModule.lua
@@ -16,7 +16,7 @@ function FireModule:__init(nInputPlane, s1x1, e1x1, e3x3, activation)
       print('Warning: s1x1 is recommended to be smaller'..
             ' then e1x1+e3x3')
    end
-   
+
    self.module = nn.Sequential()
    self.squeeze = nn.SpatialConvolution(nInputPlane, s1x1, 1, 1)
    self.expand = nn.Concat(2)
@@ -28,7 +28,7 @@ function FireModule:__init(nInputPlane, s1x1, e1x1, e3x3, activation)
    self.module:add(nn[self.activation]())
    self.module:add(self.expand)
    self.module:add(nn[self.activation]())
-   
+
    Parent.__init(self, self.module)
 end
 
diff --git a/GRU.lua b/GRU.lua
index 97935b0..460b24d 100644
--- a/GRU.lua
+++ b/GRU.lua
@@ -154,8 +154,8 @@ function GRU:_updateGradInput(input, gradOutput)
    -- backward propagate through this step
    local _gradOutput = self:getGradHiddenState(step, input)
    assert(_gradOutput)
-   self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], _gradOutput)
-   nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
+   self._gradOutputs[step] = nn.utils.recursiveCopy(self._gradOutputs[step], _gradOutput)
+   nn.utils.recursiveAdd(self._gradOutputs[step], gradOutput)
    gradOutput = self._gradOutputs[step]
 
    local gradInputTable = stepmodule:updateGradInput({input, self:getHiddenState(step-1)}, gradOutput)
diff --git a/Recurrence.lua b/Recurrence.lua
index 9e200bf..8ef0ef5 100644
--- a/Recurrence.lua
+++ b/Recurrence.lua
@@ -102,8 +102,8 @@ function Recurrence:_updateGradInput(input, gradOutput)
 
    -- backward propagate through this step
    local _gradOutput = self:getGradHiddenState(step, input)[1]
-   self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], _gradOutput)
-   nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
+   self._gradOutputs[step] = nn.utils.recursiveCopy(self._gradOutputs[step], _gradOutput)
+   nn.utils.recursiveAdd(self._gradOutputs[step], gradOutput)
    gradOutput = self._gradOutputs[step]
 
    local gradInputTable = stepmodule:updateGradInput({input, self:getHiddenState(step-1)[1]}, gradOutput)
diff --git a/RecurrentAttention.lua b/RecurrentAttention.lua
index 2f4e9c2..f20ff18 100644
--- a/RecurrentAttention.lua
+++ b/RecurrentAttention.lua
@@ -1,10 +1,10 @@
 ------------------------------------------------------------------------
---[[ RecurrentAttention ]]-- 
+--[[ RecurrentAttention ]]--
 -- Ref. A. http://papers.nips.cc/paper/5542-recurrent-models-of-visual-attention.pdf
 -- B. http://incompleteideas.net/sutton/williams-92.pdf
--- module which takes an RNN as argument with other 
--- hyper-parameters such as the maximum number of steps, 
--- action (actions sampling module like ReinforceNormal) and 
+-- module which takes an RNN as argument with other
+-- hyper-parameters such as the maximum number of steps,
+-- action (actions sampling module like ReinforceNormal) and
 ------------------------------------------------------------------------
 local RecurrentAttention, parent = torch.class("nn.RecurrentAttention", "nn.AbstractSequencer")
 
@@ -14,23 +14,23 @@ function RecurrentAttention:__init(rnn, action, nStep, hiddenSize)
    assert(torch.type(nStep) == 'number')
    assert(torch.type(hiddenSize) == 'table')
    assert(torch.type(hiddenSize[1]) == 'number', "Does not support table hidden layers" )
-   
+
    self.rnn = rnn
    -- we can decorate the module with a Recursor to make it AbstractRecurrent
    self.rnn = (not torch.isTypeOf(rnn, 'nn.AbstractRecurrent')) and nn.Recursor(rnn) or rnn
-   
+
    -- samples an x,y actions for each example
-   self.action = (not torch.isTypeOf(action, 'nn.AbstractRecurrent')) and nn.Recursor(action) or action 
+   self.action = (not torch.isTypeOf(action, 'nn.AbstractRecurrent')) and nn.Recursor(action) or action
 
    self.hiddenSize = hiddenSize
    self.nStep = nStep
-   
+
    self.modules = {self.rnn, self.action}
-   
+
    self.output = {} -- rnn output
    self.actions = {} -- action output
-   
+
    self.forwardActions = false
-   
+
    self.gradHidden = {}
 end
 
@@ -38,9 +38,9 @@ function RecurrentAttention:updateOutput(input)
    self.rnn:forget()
    self.action:forget()
    local nDim = input:dim()
-   
+
    for step=1,self.nStep do
-      
+
       if step == 1 then
          -- sample an initial starting actions by forwarding zeros through the action
         self._initInput = self._initInput or input.new()
@@ -50,12 +50,12 @@ function RecurrentAttention:updateOutput(input)
         -- sample actions from previous hidden activation (rnn output)
         self.actions[step] = self.action:updateOutput(self.output[step-1])
      end
-      
+
      -- rnn handles the recurrence internally
      local output = self.rnn:updateOutput{input, self.actions[step]}
      self.output[step] = self.forwardActions and {output, self.actions[step]} or output
   end
-   
+
   return self.output
 end
 
@@ -63,7 +63,7 @@ function RecurrentAttention:updateGradInput(input, gradOutput)
    assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
    assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
    assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")
-   
+
    -- back-propagate through time (BPTT)
    for step=self.nStep,1,-1 do
       -- 1. backward through the action layer
@@ -78,22 +78,22 @@ function RecurrentAttention:updateGradInput(input, gradOutput)
         end
         gradAction_ = self._gradAction
      end
-      
+
      if step == self.nStep then
-        self.gradHidden[step] = nn.rnn.recursiveCopy(self.gradHidden[step], gradOutput_)
+        self.gradHidden[step] = nn.utils.recursiveCopy(self.gradHidden[step], gradOutput_)
      else
         -- gradHidden = gradOutput + gradAction
-        nn.rnn.recursiveAdd(self.gradHidden[step], gradOutput_)
+        nn.utils.recursiveAdd(self.gradHidden[step], gradOutput_)
      end
-      
+
      if step == 1 then
         -- backward through initial starting actions
        self.action:updateGradInput(self._initInput, gradAction_)
      else
        local gradAction = self.action:updateGradInput(self.output[step-1], gradAction_)
-        self.gradHidden[step-1] = nn.rnn.recursiveCopy(self.gradHidden[step-1], gradAction)
+        self.gradHidden[step-1] = nn.utils.recursiveCopy(self.gradHidden[step-1], gradAction)
      end
-      
+
      -- 2. backward through the rnn layer
      local gradInput = self.rnn:updateGradInput({input, self.actions[step]}, self.gradHidden[step])[1]
      if step == self.nStep then
@@ -110,19 +110,19 @@ function RecurrentAttention:accGradParameters(input, gradOutput, scale)
    assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
    assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
    assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")
-   
+
    -- back-propagate through time (BPTT)
    for step=self.nStep,1,-1 do
       -- 1. backward through the action layer
       local gradAction_ = self.forwardActions and gradOutput[step][2] or self._gradAction
-      
+
      if step == 1 then
         -- backward through initial starting actions
        self.action:accGradParameters(self._initInput, gradAction_, scale)
      else
        self.action:accGradParameters(self.output[step-1], gradAction_, scale)
      end
-      
+
      -- 2. backward through the rnn layer
      self.rnn:accGradParameters({input, self.actions[step]}, self.gradHidden[step], scale)
   end
@@ -132,12 +132,12 @@ function RecurrentAttention:accUpdateGradParameters(input, gradOutput, lr)
    assert(self.rnn.step - 1 == self.nStep, "inconsistent rnn steps")
    assert(torch.type(gradOutput) == 'table', "expecting gradOutput table")
    assert(#gradOutput == self.nStep, "gradOutput should have nStep elements")
-   
+
    -- backward through the action layers
    for step=self.nStep,1,-1 do
       -- 1. backward through the action layer
       local gradAction_ = self.forwardActions and gradOutput[step][2] or self._gradAction
-      
+
      if step == 1 then
         -- backward through initial starting actions
        self.action:accUpdateGradParameters(self._initInput, gradAction_, lr)
@@ -145,7 +145,7 @@ function RecurrentAttention:accUpdateGradParameters(input, gradOutput, lr)
        -- Note : gradOutput is ignored by REINFORCE modules so we give action.output as a dummy variable
        self.action:accUpdateGradParameters(self.output[step-1], gradAction_, lr)
      end
-      
+
      -- 2. backward through the rnn layer
      self.rnn:accUpdateGradParameters({input, self.actions[step]}, self.gradHidden[step], lr)
   end
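For reference, a sketch of how the nn.RecurrentAttention arguments above fit together; it is not part of the patch, the glimpse and action networks are illustrative stand-ins, and only the (rnn, action, nStep, hiddenSize) signature and the {input, action} step input are taken from the code:

require 'rnn'
-- stand-in glimpse network: consumes {image, action} at each step, returns a 256-d hidden state
local glimpse = nn.Sequential()
   :add(nn.ParallelTable():add(nn.View(-1):setNumInputDims(3)):add(nn.Identity()))
   :add(nn.JoinTable(1,1))
   :add(nn.Linear(3*8*8 + 2, 256))
   :add(nn.ReLU())
-- stand-in action network: samples an (x,y) location from the previous hidden state
local action = nn.Sequential()
   :add(nn.Linear(256, 2))
   :add(nn.HardTanh())
   :add(nn.ReinforceNormal(0.1))
local ra = nn.RecurrentAttention(glimpse, action, 5, {256})
local output = ra:forward(torch.randn(4, 3, 8, 8)) -- table of 5 tensors of size 4x256
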
diff --git a/Repeater.lua b/Repeater.lua
index c98b06e..f114dbf 100644
--- a/Repeater.lua
+++ b/Repeater.lua
@@ -25,7 +25,7 @@ function Repeater:updateOutput(input)
    self.module:forget()
    -- TODO make copy outputs optional
    for step=1,self.seqlen do
-      self.output[step] = nn.rnn.recursiveCopy(self.output[step], self.module:updateOutput(input))
+      self.output[step] = nn.utils.recursiveCopy(self.output[step], self.module:updateOutput(input))
    end
    return self.output
 end
@@ -39,9 +39,9 @@ function Repeater:updateGradInput(input, gradOutput)
    for step=self.seqlen,1,-1 do
       local gradInput = self.module:updateGradInput(input, gradOutput[step])
       if step == self.seqlen then
-         self.gradInput = nn.rnn.recursiveCopy(self.gradInput, gradInput)
+         self.gradInput = nn.utils.recursiveCopy(self.gradInput, gradInput)
       else
-         nn.rnn.recursiveAdd(self.gradInput, gradInput)
+         nn.utils.recursiveAdd(self.gradInput, gradInput)
       end
    end
 
diff --git a/deprecated/LSTM.lua b/deprecated/LSTM.lua
index fbbee7f..5c4560c 100644
--- a/deprecated/LSTM.lua
+++ b/deprecated/LSTM.lua
@@ -251,8 +251,8 @@ function LSTM:_updateGradInput(input, gradOutput)
    local _gradOutput, gradCell = gradHiddenState[1], gradHiddenState[2]
    assert(_gradOutput and gradCell)
 
-   self._gradOutputs[step] = nn.rnn.recursiveCopy(self._gradOutputs[step], _gradOutput)
-   nn.rnn.recursiveAdd(self._gradOutputs[step], gradOutput)
+   self._gradOutputs[step] = nn.utils.recursiveCopy(self._gradOutputs[step], _gradOutput)
+   nn.utils.recursiveAdd(self._gradOutputs[step], gradOutput)
    gradOutput = self._gradOutputs[step]
 
    local inputTable = self:getHiddenState(step-1)
diff --git a/init.lua b/init.lua
index 3b016d5..6990302 100644
--- a/init.lua
+++ b/init.lua
@@ -4,7 +4,7 @@ require 'nn'
 
 -- create global rnn table:
 rnn = {}
-rnn.version = 2.5 -- AbstractRecurrent(rho) -> AbstractRecurrent(module)
+rnn.version = 2.6 -- zero-masking v2 : setZeroMask(zeroMask)
 
 -- lua 5.2 compat
 
@@ -24,7 +24,6 @@ paths.require 'librnn'
 
 unpack = unpack or table.unpack
 
-require('rnn.recursiveUtils')
 require('rnn.utils')
 
 -- extensions to existing nn.Module
@@ -140,6 +139,7 @@ require('rnn.SeqReverseSequence')
 require('rnn.SeqBRNN')
 
 -- recurrent criterions:
+require('rnn.AbstractSequencerCriterion')
 require('rnn.SequencerCriterion')
 require('rnn.RepeaterCriterion')
 require('rnn.MaskZeroCriterion')
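With recursiveUtils.lua removed below, code outside this repository that still calls the old rnn.recursive* (a.k.a. nn.rnn.recursive*) helpers is expected to migrate the same way as the call sites above; a hypothetical example (not part of the patch), where recursiveCopy/recursiveAdd already exist in nn and the remaining helpers are re-added under nn.utils in utils.lua below:

-- old: gradBuffer = nn.rnn.recursiveCopy(gradBuffer, gradOutput)
-- old: nn.rnn.recursiveAdd(gradBuffer, extraGrad)
gradBuffer = nn.utils.recursiveCopy(gradBuffer, gradOutput)
nn.utils.recursiveAdd(gradBuffer, extraGrad)
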
instead") - end - return t1, t2 -end - -function rnn.recursiveCopy(t1,t2) - if torch.type(t2) == 'table' then - t1 = (torch.type(t1) == 'table') and t1 or {t1} - for key,_ in pairs(t2) do - t1[key], t2[key] = rnn.recursiveCopy(t1[key], t2[key]) - end - elseif torch.isTensor(t2) then - t1 = torch.isTensor(t1) and t1 or t2.new() - t1:resizeAs(t2):copy(t2) - else - error("expecting nested tensors or tables. Got ".. - torch.type(t1).." and "..torch.type(t2).." instead") - end - return t1, t2 -end - -function rnn.recursiveAdd(t1, t2) - if torch.type(t2) == 'table' then - t1 = (torch.type(t1) == 'table') and t1 or {t1} - for key,_ in pairs(t2) do - t1[key], t2[key] = rnn.recursiveAdd(t1[key], t2[key]) - end - elseif torch.isTensor(t1) and torch.isTensor(t2) then - t1:add(t2) - else - error("expecting nested tensors or tables. Got ".. - torch.type(t1).." and "..torch.type(t2).." instead") - end - return t1, t2 -end - -function rnn.recursiveTensorEq(t1, t2) - if torch.type(t2) == 'table' then - local isEqual = true - if torch.type(t1) ~= 'table' then - return false - end - for key,_ in pairs(t2) do - isEqual = isEqual and rnn.recursiveTensorEq(t1[key], t2[key]) - end - return isEqual - elseif torch.isTensor(t1) and torch.isTensor(t2) then - local diff = t1-t2 - local err = diff:abs():max() - return err < 0.00001 - else - error("expecting nested tensors or tables. Got ".. - torch.type(t1).." and "..torch.type(t2).." instead") - end -end - -function rnn.recursiveNormal(t2) - if torch.type(t2) == 'table' then - for key,_ in pairs(t2) do - t2[key] = rnn.recursiveNormal(t2[key]) - end - elseif torch.isTensor(t2) then - t2:normal() - else - error("expecting tensor or table thereof. Got " - ..torch.type(t2).." instead") - end - return t2 -end - -function rnn.recursiveFill(t2, val) - if torch.type(t2) == 'table' then - for key,_ in pairs(t2) do - t2[key] = rnn.recursiveFill(t2[key], val) - end - elseif torch.isTensor(t2) then - t2:fill(val) - else - error("expecting tensor or table thereof. Got " - ..torch.type(t2).." instead") - end - return t2 -end - -function rnn.recursiveType(param, type_str) - if torch.type(param) == 'table' then - for i = 1, #param do - param[i] = rnn.recursiveType(param[i], type_str) - end - else - if torch.typename(param) and - torch.typename(param):find('torch%..+Tensor') then - param = param:type(type_str) - end - end - return param -end - -function rnn.recursiveSum(t2) - local sum = 0 - if torch.type(t2) == 'table' then - for key,_ in pairs(t2) do - sum = sum + rnn.recursiveSum(t2[key], val) - end - elseif torch.isTensor(t2) then - return t2:sum() - else - error("expecting tensor or table thereof. Got " - ..torch.type(t2).." instead") - end - return sum -end - -function rnn.recursiveNew(t2) - if torch.type(t2) == 'table' then - local t1 = {} - for key,_ in pairs(t2) do - t1[key] = rnn.recursiveNew(t2[key]) - end - return t1 - elseif torch.isTensor(t2) then - return t2.new() - else - error("expecting tensor or table thereof. Got " - ..torch.type(t2).." 
instead") - end -end diff --git a/utils.lua b/utils.lua index 053665f..34e29b6 100644 --- a/utils.lua +++ b/utils.lua @@ -24,4 +24,208 @@ function torch.getBuffer(namespace, buffername, classname) end return buffer -end \ No newline at end of file +end + +function torch.isByteTensor(tensor) + local typename = torch.typename(tensor) + if typename and typename:find('torch.*ByteTensor') then + return true + end + return false +end + +function nn.utils.getZeroMaskBatch(batch, zeroMask) + -- get first tensor + local first = nn.utils.recursiveGetFirst(batch) + first = first:contiguous():view(first:size(1), -1) -- collapse non-batch dims + + -- build mask (1 where norm is 0 in first) + local _zeroMask = torch.getBuffer('getZeroMaskBatch', '_zeroMask', first) + _zeroMask:norm(first, 2, 2) + zeroMask = zeroMask or ( + (torch.type(first) == 'torch.CudaTensor') and torch.CudaByteTensor() + or (torch.type(first) == 'torch.ClTensor') and torch.ClTensor() + or torch.ByteTensor() + ) + _zeroMask.eq(zeroMask, _zeroMask, 0) + return zeroMask:view(zeroMask:size(1)) +end + +function nn.utils.getZeroMaskSequence(sequence, zeroMask) + assert(torch.isTensor(sequence), "nn.utils.getZeroMaskSequence expecting tensor for arg 1") + assert(sequence:dim() >= 2, "nn.utils.getZeroMaskSequence expecting seqlen x batchsize [x ...] tensor for arg 1") + + sequence = sequence:contiguous():view(sequence:size(1), sequence:size(2), -1) + -- build mask (1 where norm is 0 in first) + local _zeroMask = torch.getBuffer('getZeroMaskSequence', '_zeroMask', sequence) + _zeroMask:norm(sequence, 2, 3) + zeroMask = zeroMask or ( + (torch.type(sequence) == 'torch.CudaTensor') and torch.CudaByteTensor() + or (torch.type(sequence) == 'torch.ClTensor') and torch.ClTensor() + or torch.ByteTensor() + ) + _zeroMask.eq(zeroMask, _zeroMask, 0) + return zeroMask:view(sequence:size(1), sequence:size(2)) +end + +function nn.utils.recursiveSet(t1,t2) + if torch.type(t2) == 'table' then + t1 = (torch.type(t1) == 'table') and t1 or {t1} + for key,_ in pairs(t2) do + t1[key], t2[key] = nn.utils.recursiveSet(t1[key], t2[key]) + end + for i=#t2+1,#t1 do + t1[i] = nil + end + elseif torch.isTensor(t2) then + t1 = torch.isTensor(t1) and t1 or t2.new() + t1:set(t2) + else + error("expecting nested tensors or tables. Got ".. + torch.type(t1).." and "..torch.type(t2).." instead") + end + return t1, t2 +end + +function nn.utils.recursiveTensorEq(t1, t2) + if torch.type(t2) == 'table' then + local isEqual = true + if torch.type(t1) ~= 'table' then + return false + end + for key,_ in pairs(t2) do + isEqual = isEqual and nn.utils.recursiveTensorEq(t1[key], t2[key]) + end + return isEqual + elseif torch.isTensor(t1) and torch.isTensor(t2) then + local diff = t1-t2 + local err = diff:abs():max() + return err < 0.00001 + else + error("expecting nested tensors or tables. Got ".. + torch.type(t1).." and "..torch.type(t2).." instead") + end +end + +function nn.utils.recursiveNormal(t2) + if torch.type(t2) == 'table' then + for key,_ in pairs(t2) do + t2[key] = nn.utils.recursiveNormal(t2[key]) + end + elseif torch.isTensor(t2) then + t2:normal() + else + error("expecting tensor or table thereof. Got " + ..torch.type(t2).." instead") + end + return t2 +end + +function nn.utils.recursiveSum(t2) + local sum = 0 + if torch.type(t2) == 'table' then + for key,_ in pairs(t2) do + sum = sum + nn.utils.recursiveSum(t2[key], val) + end + elseif torch.isTensor(t2) then + return t2:sum() + else + error("expecting tensor or table thereof. Got " + ..torch.type(t2).." 
instead") + end + return sum +end + +function nn.utils.recursiveNew(t2) + if torch.type(t2) == 'table' then + local t1 = {} + for key,_ in pairs(t2) do + t1[key] = nn.utils.recursiveNew(t2[key]) + end + return t1 + elseif torch.isTensor(t2) then + return t2.new() + else + error("expecting tensor or table thereof. Got " + ..torch.type(t2).." instead") + end +end + +function nn.utils.recursiveGetFirst(input) + if torch.type(input) == 'table' then + return nn.utils.recursiveGetFirst(input[1]) + else + assert(torch.isTensor(input)) + return input + end +end + +-- in-place set tensor to zero where zeroMask is 1 +function nn.utils.recursiveZeroMask(tensor, mask) + if torch.type(tensor) == 'table' then + for k,tensor_k in ipairs(tensor) do + nn.utils.recursiveMask(tensor_k) + end + else + assert(torch.isTensor(tensor)) + + local tensorSize = tensor:size():fill(1) + tensorSize[1] = tensor:size(1) + assert(zeroMask:dim() <= tensor:dim()) + zeroMask = zeroMask:view(tensorSize):expandAs(tensor) + -- set tensor to zero where zeroMask is 1 + tensor:maskedFill(zeroMask, 0) + end + return tensor +end + +function nn.utils.recursiveDiv(tensor, scalar) + if torch.type(tensor) == 'table' then + for j=1,#tensor do + nn.utils.recursiveDiv(tensor[j], scalar) + end + else + tensor:div(scalar) + end +end + +function nn.utils.recursiveIndex(dst, src, indices) + if torch.type(src) == 'table' then + dst = torch.type(dst) == 'table' and dst or {} + for k,v in ipairs(src) do + dst[k] = nn.utils.recursiveIndex(dst[k], v, indices) + end + for i=#src+1,#dst do + dst[i] = nil + end + else + assert(torch.isTensor(src)) + dst = torch.isTensor(dst) and dst or src.new() + + dst:index(src, 1, indices) + end + return dst +end + +function nn.utils.recursiveIndexCopy(dst, indices, src) + if torch.type(src) == 'table' then + dst = (torch.type(dst) == 'table') and dst or {dst} + for key,src_ in pairs(src) do + dst[key] = self:recursiveMaskGradInput(dst[key], indices, src_) + end + for i=#src+1,#dst do + dst[i] = nil + end + elseif torch.isTensor(input) then + dst = torch.isTensor(dst) and dst or input.new() + dst:resizeAs(input):zero() + if indices:nElement() > 0 then + assert(src) + dst:indexCopy(1, indices, src) + end + else + error("expecting nested tensors or tables. Got ".. + torch.type(dst).." and "..torch.type(input).." instead") + end + return dst +end