MaskZero v2
Nicholas Leonard committed May 3, 2017
1 parent 69bfa49 commit e7c456b
Showing 23 changed files with 368 additions and 1,077 deletions.
75 changes: 47 additions & 28 deletions AbstractRecurrent.lua
@@ -32,24 +32,18 @@ function AbstractRecurrent:getStepModule(step)
return stepmodule
end

function AbstractRecurrent:maskZero(nInputDim)
local stepmodule = nn.MaskZero(self.modules[1], nInputDim, true)
self.sharedClones = {stepmodule}
self.modules[1] = stepmodule
return self
end

function AbstractRecurrent:trimZero(nInputDim)
if torch.typename(self)=='nn.GRU' and self.p ~= 0 then
assert(self.mono, "TrimZero for BGRU needs `mono` option.")
end
local stepmodule = nn.TrimZero(self.modules[1], nInputDim, true)
self.sharedClones = {stepmodule}
self.modules[1] = stepmodule
return self
end

function AbstractRecurrent:updateOutput(input)
if self.zeroMask then
-- where zeroMask = 1, the past is forgotten,
-- that is, the output/gradOutput is zero'd
local stepmodule = self:getStepModule(self.step)
self.zeroMaskStep = self.zeroMaskStep + 1
if self.zeroMaskStep > self.zeroMask:size(1) then
error"AbstractRecurrent.updateOutput called more times than self.zeroMask:size(1)"
end
stepmodule:setZeroMask(self.zeroMask[self.zeroMaskStep])
end
-- feed-forward for one time-step
self.output = self:_updateOutput(input)

@@ -64,6 +58,10 @@ function AbstractRecurrent:updateOutput(input)
end

function AbstractRecurrent:updateGradInput(input, gradOutput)
if self.zeroMask and self.zeroMask:size(1) ~= self.zeroMaskStep then
error"AbstractRecurrent.updateOutput called less times than self.zeroMask:size(1)"
end

-- updateGradInput should be called in reverse order of time
self.updateGradInputStep = self.updateGradInputStep or self.step

@@ -86,7 +84,7 @@ function AbstractRecurrent:accGradParameters(input, gradOutput, scale)
self.accGradParametersStep = self.accGradParametersStep - 1
end

-- goes hand in hand with the next method : forget()
-- goes hand in hand with forget()
-- this method brings the oldest memory to the current step
function AbstractRecurrent:recycle()
self.nSharedClone = self.nSharedClone or _.size(self.sharedClones)
@@ -113,6 +111,7 @@ function nn.AbstractRecurrent:clearState()
clone:clearState()
end
self.modules[1]:clearState()
self.zeroMask = nil
return parent.clearState(self)
end

@@ -189,6 +188,32 @@ function AbstractRecurrent:type(type, tensorcache)
end)
end

function AbstractRecurrent:maskZero(v1)
if not self.maskzero then
self.maskzero = true
local stepmodule = nn.MaskZero(self.modules[1], v1)
self.sharedClones = {stepmodule}
self.modules[1] = stepmodule
end
return self
end

function AbstractRecurrent:setZeroMask(zeroMask)
if zeroMask == false then
self.zeroMask = false
for k,stepmodule in pairs(self.sharedClones) do
stepmodule:setZeroMask(zeroMask)
end
elseif torch.isTypeOf(self.modules[1], 'nn.AbstractRecurrent') then
self.modules[1]:setZeroMask(zeroMask)
else
assert(zeroMask:dim() >= 2, "Expecting dim >= 2 for zeroMask. For example, seqlen x batchsize")
-- reserve for later. Each step will be masked in updateOutput
self.zeroMask = zeroMask
self.zeroMaskStep = 0
end
end

function AbstractRecurrent:training()
return self:includingSharedClones(function()
return parent.training(self)
@@ -261,18 +286,12 @@ function AbstractRecurrent:setGradHiddenState(step, hiddenState)
error"Not Implemented"
end

-- backwards compatibility
AbstractRecurrent.recursiveResizeAs = rnn.recursiveResizeAs
AbstractRecurrent.recursiveSet = rnn.recursiveSet
AbstractRecurrent.recursiveCopy = rnn.recursiveCopy
AbstractRecurrent.recursiveAdd = rnn.recursiveAdd
AbstractRecurrent.recursiveTensorEq = rnn.recursiveTensorEq
AbstractRecurrent.recursiveNormal = rnn.recursiveNormal

function AbstractRecurrent:__tostring__()
if self.inputSize and self.outputSize then
return self.__typename .. string.format("(%d -> %d)", self.inputSize, self.outputSize)
local inputsize = self.inputsize or self.inputSize
local outputsize = self.outputsize or self.outputSize
if inputsize and outputsize then
return self.__typename .. string.format("(%d -> %d)", inputsize, outputsize)
else
return parent.__tostring__(self)
return self.__typename
end
end
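For context, here is a minimal sketch of the masking workflow these additions enable (the module choice, sizes and hand-built mask are illustrative, not part of the commit): maskZero() wraps the step module in nn.MaskZero, setZeroMask() registers a seqlen x batchsize ByteTensor, and each updateOutput() call forwards one row of that mask to the current step module.

require 'rnn'

local seqlen, batchsize, inputsize, outputsize = 5, 3, 4, 4
local rnn = nn.LinearRNN(inputsize, outputsize):maskZero()

-- pretend the last two steps of the third sample are padding
local input = torch.randn(seqlen, batchsize, inputsize)
input[{{4,5}, 3}]:zero()

-- zeroMask[t][b] = 1 means step t of sample b is masked
local zeroMask = torch.ByteTensor(seqlen, batchsize):zero()
zeroMask[{{4,5}, 3}]:fill(1)

rnn:setZeroMask(zeroMask) -- also resets zeroMaskStep to 0
for t = 1, seqlen do
   local output = rnn:forward(input[t]) -- output[3] is zeroed once t > 3
end
rnn:forget() -- before the next sequence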
5 changes: 2 additions & 3 deletions AbstractSequencerCriterion.lua
@@ -8,7 +8,6 @@ local AbstractSequencerCriterion, parent = torch.class('nn.AbstractSequencerCrit

function AbstractSequencerCriterion:__init(criterion, sizeAverage)
parent.__init(self)
self.criterion = criterion
if torch.isTypeOf(criterion, 'nn.ModuleCriterion') then
error(torch.type(self).." shouldn't decorate a ModuleCriterion. "..
"Instead, try the other way around : "..
@@ -20,14 +19,14 @@ function AbstractSequencerCriterion:__init(criterion, sizeAverage)
else
self.sizeAverage = false
end
self.clones = {}
self.clones = {criterion}
end

function AbstractSequencerCriterion:getStepCriterion(step)
assert(step, "expecting step at arg 1")
local criterion = self.clones[step]
if not criterion then
criterion = self.criterion:clone()
criterion = self.clones[1]:clone()
self.clones[step] = criterion
end
return criterion
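To illustrate the change above, the wrapped criterion now lives in clones[1] and per-step clones are created lazily by getStepCriterion. A small usage sketch; the criterion and sizes below are illustrative:

local crit = nn.SequencerCriterion(nn.MSECriterion())
-- crit.clones[1] is the MSECriterion passed in; getStepCriterion(t) clones it on demand
local input = torch.randn(5, 3, 10)  -- seqlen x batchsize x featsize
local target = torch.randn(5, 3, 10)
local err = crit:forward(input, target)
crit:backward(input, target)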
1 change: 0 additions & 1 deletion CMakeLists.txt
@@ -45,7 +45,6 @@ SET(luasrc
SeqReverseSequence.lua
Sequencer.lua
SequencerCriterion.lua
TrimZero.lua
ZeroGrad.lua
test/bigtest.lua
test/test.lua
2 changes: 1 addition & 1 deletion LinearRNN.lua
@@ -13,5 +13,5 @@ function LinearRNN:__init(inputsize, outputsize, transfer)
end

function LinearRNN:__tostring__()
return torch.type(self) .. "(" .. self.inputsize .. ", " .. self.outputsize ..")"
return torch.type(self) .. "(" .. self.inputsize .. " -> " .. self.outputsize ..")"
end
8 changes: 2 additions & 6 deletions LookupRNN.lua
@@ -5,7 +5,7 @@ function LookupRNN:__init(nindex, outputsize, transfer, merge)
merge = merge or nn.CAddTable()
local stepmodule = nn.Sequential() -- input is {x[t], h[t-1]}
:add(nn.ParallelTable()
:add(nn.LookupTable(nindex, outputsize)) -- input layer
:add(nn.LookupTableMaskZero(nindex, outputsize)) -- input layer
:add(nn.Linear(outputsize, outputsize))) -- recurrent layer
:add(merge)
:add(transfer)
@@ -14,10 +14,6 @@ function LookupRNN:__init(nindex, outputsize, transfer, merge)
self.outputsize = outputsize
end

function LookupRNN:maskZero()
error"Not Implemented"
end

function LookupRNN:__tostring__()
return torch.type(self) .. "(" .. self.nindex .. ", " .. self.outputsize ..")"
return torch.type(self) .. "(" .. self.nindex .. " -> " .. self.outputsize ..")"
end
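With the input layer switched to LookupTableMaskZero, a LookupRNN can consume 0-padded index batches directly, so the maskZero() override that raised "Not Implemented" is no longer needed. A brief sketch; sizes are illustrative:

local rnn = nn.LookupRNN(100, 16) -- nindex = 100, outputsize = 16
local x = torch.LongTensor{1, 7, 0} -- a batch of 3 word indices; 0 is padding
local h = rnn:forward(x) -- 3 x 16; the embedding looked up for index 0 is a zero vector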
6 changes: 3 additions & 3 deletions LookupTableMaskZero.lua
@@ -5,17 +5,17 @@ function LookupTableMaskZero:__init(nIndex, nOutput)
end

function LookupTableMaskZero:updateOutput(input)
self.weight[1]:zero()
if self.__input and (torch.type(self.__input) ~= torch.type(input)) then
self.__input = nil -- fixes old casting bug
end
self.__input = self.__input or input.new()
self.__input:resizeAs(input):add(input, 1)
return parent.updateOutput(self, self.__input)
end

function LookupTableMaskZero:accGradParameters(input, gradOutput, scale)
parent.accGradParameters(self, self.__input, gradOutput, scale)
end

function LookupTableMaskZero:type(type, cache)
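The index shift above is the core trick: row 1 of the weight matrix is kept at zero and input index i is looked up as i + 1, so the padding index 0 becomes a legal lookup that returns a zero vector. A short sketch (vocabulary size and embedding size are illustrative):

local lookup = nn.LookupTableMaskZero(100, 8) -- nIndex = 100, nOutput = 8
local input = torch.LongTensor{{1, 2, 0}, {3, 0, 0}} -- batchsize x seqlen, 0 = padding
local emb = lookup:forward(input) -- 2 x 3 x 8; rows corresponding to index 0 are all zeros
-- gradients for index 0 accumulate into the first weight row, which is re-zeroed at every updateOutput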
128 changes: 56 additions & 72 deletions MaskZero.lua
@@ -1,99 +1,83 @@
------------------------------------------------------------------------
--[[ MaskZero ]]--
-- Decorator that zeroes the output rows of the encapsulated module
-- for commensurate input rows which are tensors of zeros
-- Zeroes the elements of the state tensors
-- (output/gradOutput/input/gradInput) of the encapsulated module
-- for commensurate elements that are 1 in self.zeroMask.
-- By default only output/gradOutput are zeroMasked.
-- self.zeroMask is set with setZeroMask(zeroMask).
-- Only works in batch-mode.
-- Note that when input/gradInput are zeroMasked, it is in-place
------------------------------------------------------------------------
local MaskZero, parent = torch.class("nn.MaskZero", "nn.Decorator")

function MaskZero:__init(module, nInputDim, silent)
parent.__init(self, module)
assert(torch.isTypeOf(module, 'nn.Module'))
if torch.isTypeOf(module, 'nn.AbstractRecurrent') and not silent then
print("Warning : you are most likely using MaskZero the wrong way. "
.."You should probably use AbstractRecurrent:maskZero() so that "
.."it wraps the internal AbstractRecurrent.recurrentModule instead of "
.."wrapping the AbstractRecurrent module itself.")
end
assert(torch.type(nInputDim) == 'number', 'Expecting nInputDim number at arg 1')
self.nInputDim = nInputDim
end

function MaskZero:recursiveGetFirst(input)
if torch.type(input) == 'table' then
return self:recursiveGetFirst(input[1])
else
assert(torch.isTensor(input))
return input
end
end

function MaskZero:recursiveMask(output, input, mask)
if torch.type(input) == 'table' then
output = torch.type(output) == 'table' and output or {}
for k,v in ipairs(input) do
output[k] = self:recursiveMask(output[k], v, mask)
end
else
assert(torch.isTensor(input))
output = torch.isTensor(output) and output or input.new()

-- make sure mask has the same dimension as the input tensor
local inputSize = input:size():fill(1)
if self.batchmode then
inputSize[1] = input:size(1)
end
mask:resize(inputSize)
-- build mask
local zeroMask = mask:expandAs(input)
output:resizeAs(input):copy(input)
output:maskedFill(zeroMask, 0)
end
return output
end

function MaskZero:__init(module, v1, maskinput, maskoutput)
parent.__init(self, module)
assert(torch.isTypeOf(module, 'nn.Module'))
self.maskinput = maskinput -- defaults to false
self.maskoutput = maskoutput == nil and true or maskoutput -- defaults to true
self.v2 = not v1
end

function MaskZero:updateOutput(input)
-- recurrent module input is always the first one
local rmi = self:recursiveGetFirst(input):contiguous()
if rmi:dim() == self.nInputDim then
self.batchmode = false
rmi = rmi:view(-1) -- collapse dims
elseif rmi:dim() - 1 == self.nInputDim then
self.batchmode = true
rmi = rmi:view(rmi:size(1), -1) -- collapse non-batch dims
else
error("nInputDim error: "..rmi:dim()..", "..self.nInputDim)
if self.v2 then
assert(self.zeroMask ~= nil, "MaskZero expecting zeroMask tensor or false")
else -- backwards compat
self.zeroMask = nn.utils.getZeroMaskBatch(input, self.zeroMask)
end

-- build mask
local vectorDim = rmi:dim()
self._zeroMask = self._zeroMask or rmi.new()
self._zeroMask:norm(rmi, 2, vectorDim)
self.zeroMask = self.zeroMask or (
(torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor()
or (torch.type(rmi) == 'torch.ClTensor') and torch.ClTensor()
or torch.ByteTensor()
)
self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
if self.maskinput and self.zeroMask then
nn.utils.recursiveZeroMask(input, self.zeroMask)
end

-- forward through decorated module
local output = self.modules[1]:updateOutput(input)

self.output = self:recursiveMask(self.output, output, self.zeroMask)
if self.maskoutput and self.zeroMask then
self.output = nn.utils.recursiveCopy(self.output, output)
nn.utils.recursiveZeroMask(self.output, self.zeroMask)
else
self.output = output
end

return self.output
end

function MaskZero:updateGradInput(input, gradOutput)
-- zero gradOutputs before backpropagating through decorated module
self.gradOutput = self:recursiveMask(self.gradOutput, gradOutput, self.zeroMask)
assert(self.zeroMask ~= nil, "MaskZero expecting zeroMask tensor or false")

if self.maskoutput and self.zeroMask then
self.gradOutput = nn.utils.recursiveCopy(self.gradOutput, gradOutput)
nn.utils.recursiveZeroMask(self.gradOutput, self.zeroMask)
gradOutput = self.gradOutput
end

self.gradInput = self.modules[1]:updateGradInput(input, gradOutput)

if self.maskinput and self.zeroMask then
nn.utils.recursiveZeroMask(self.gradInput, self.zeroMask)
end

self.gradInput = self.modules[1]:updateGradInput(input, self.gradOutput)
return self.gradInput
end

function MaskZero:type(type, ...)
function MaskZero:clearState()
self.output = nil
self.gradInput = nil
self.zeroMask = nil
self._zeroMask = nil
self._maskbyte = nil
self._maskindices = nil
return self
end

function MaskZero:type(type, ...)
self:clearState()
return parent.type(self, type, ...)
end

function MaskZero:setZeroMask(zeroMask)
if zeroMask == false then
self.zeroMask = false
else
assert(torch.isByteTensor(zeroMask))
assert(zeroMask:isContiguous())
self.zeroMask = zeroMask
end
end
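A minimal sketch of the new MaskZero contract on its own, outside any recurrent container (the Linear module and sizes are illustrative): in v2 the caller supplies the mask through setZeroMask instead of MaskZero inferring it from zero input rows, and only the output is masked by default.

local mz = nn.MaskZero(nn.Linear(3, 2)) -- v2: maskoutput defaults to true, maskinput to false

local input = torch.randn(4, 3) -- batchsize x inputsize
local zeroMask = torch.ByteTensor(4):zero()
zeroMask[2] = 1 -- mask out the second sample

mz:setZeroMask(zeroMask)
local output = mz:forward(input) -- output[2] is all zeros
mz:setZeroMask(false) -- disables masking for subsequent batches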
5 changes: 2 additions & 3 deletions MaskZeroCriterion.lua
@@ -5,11 +5,11 @@
------------------------------------------------------------------------
local MaskZeroCriterion, parent = torch.class("nn.MaskZeroCriterion", "nn.Criterion")

function MaskZeroCriterion:__init(criterion)
function MaskZeroCriterion:__init(criterion, v1)
parent.__init(self)
self.criterion = criterion
assert(torch.isTypeOf(criterion, 'nn.Criterion'))
self.v2 = true
self.v2 = not v1
end

function MaskZeroCriterion:updateOutput(input, target)
@@ -39,7 +39,6 @@ function MaskZeroCriterion:updateOutput(input, target)
-- indexSelect the input
self.input = nn.utils.recursiveIndex(self.input, input, 1, self._indices)
self.target = nn.utils.recursiveIndex(self.target, target, 1, self._indices)

-- forward through decorated criterion
self.output = self.criterion:updateOutput(self.input, self.target)
end
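For completeness, a hedged sketch of how the criterion side is meant to be used: the hunk above gathers the non-masked rows with recursiveIndex and only those reach the decorated criterion. The setZeroMask call below is an assumption (mirroring MaskZero:setZeroMask; it is not part of the hunks shown here), as are the criterion and sizes.

local crit = nn.MaskZeroCriterion(nn.ClassNLLCriterion())

local input = nn.LogSoftMax():forward(torch.randn(4, 10)) -- batchsize x nclass log-probs
local target = torch.LongTensor(4):random(1, 10)

local zeroMask = torch.ByteTensor(4):zero()
zeroMask[3] = 1 -- sample 3 is padding and is excluded from the loss

crit:setZeroMask(zeroMask) -- assumed API, see note above
local err = crit:forward(input, target)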