Commit e7c456b
Nicholas Leonard committed on May 3, 2017
1 parent: 69bfa49
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 23 changed files with 368 additions and 1,077 deletions.
@@ -1,99 +1,83 @@
 ------------------------------------------------------------------------
 --[[ MaskZero ]]--
--- Decorator that zeroes the output rows of the encapsulated module
--- for commensurate input rows which are tensors of zeros
+-- Zeroes the elements of the state tensors
+-- (output/gradOutput/input/gradInput) of the encapsulated module
+-- for commensurate elements that are 1 in self.zeroMask.
+-- By default only output/gradOutput are zeroMasked.
+-- self.zeroMask is set with setZeroMask(zeroMask).
+-- Only works in batch-mode.
+-- Note that when input/gradInput are zeroMasked, it is in-place
 ------------------------------------------------------------------------
 local MaskZero, parent = torch.class("nn.MaskZero", "nn.Decorator")
 
-function MaskZero:__init(module, nInputDim, silent)
+function MaskZero:__init(module, v1, maskinput, maskoutput)
    parent.__init(self, module)
    assert(torch.isTypeOf(module, 'nn.Module'))
-   if torch.isTypeOf(module, 'nn.AbstractRecurrent') and not silent then
-      print("Warning : you are most likely using MaskZero the wrong way. "
-         .."You should probably use AbstractRecurrent:maskZero() so that "
-         .."it wraps the internal AbstractRecurrent.recurrentModule instead of "
-         .."wrapping the AbstractRecurrent module itself.")
-   end
-   assert(torch.type(nInputDim) == 'number', 'Expecting nInputDim number at arg 1')
-   self.nInputDim = nInputDim
-end
-
-function MaskZero:recursiveGetFirst(input)
-   if torch.type(input) == 'table' then
-      return self:recursiveGetFirst(input[1])
-   else
-      assert(torch.isTensor(input))
-      return input
-   end
-end
-
-function MaskZero:recursiveMask(output, input, mask)
-   if torch.type(input) == 'table' then
-      output = torch.type(output) == 'table' and output or {}
-      for k,v in ipairs(input) do
-         output[k] = self:recursiveMask(output[k], v, mask)
-      end
-   else
-      assert(torch.isTensor(input))
-      output = torch.isTensor(output) and output or input.new()
-
-      -- make sure mask has the same dimension as the input tensor
-      local inputSize = input:size():fill(1)
-      if self.batchmode then
-         inputSize[1] = input:size(1)
-      end
-      mask:resize(inputSize)
-      -- build mask
-      local zeroMask = mask:expandAs(input)
-      output:resizeAs(input):copy(input)
-      output:maskedFill(zeroMask, 0)
-   end
-   return output
+   self.maskinput = maskinput -- defaults to false
+   self.maskoutput = maskoutput == nil and true or maskoutput -- defaults to true
+   self.v2 = not v1
 end
 
 function MaskZero:updateOutput(input)
-   -- recurrent module input is always the first one
-   local rmi = self:recursiveGetFirst(input):contiguous()
-   if rmi:dim() == self.nInputDim then
-      self.batchmode = false
-      rmi = rmi:view(-1) -- collapse dims
-   elseif rmi:dim() - 1 == self.nInputDim then
-      self.batchmode = true
-      rmi = rmi:view(rmi:size(1), -1) -- collapse non-batch dims
-   else
-      error("nInputDim error: "..rmi:dim()..", "..self.nInputDim)
+   if self.v2 then
+      assert(self.zeroMask ~= nil, "MaskZero expecting zeroMask tensor or false")
+   else -- backwards compat
+      self.zeroMask = nn.utils.getZeroMaskBatch(input, self.zeroMask)
    end
 
-   -- build mask
-   local vectorDim = rmi:dim()
-   self._zeroMask = self._zeroMask or rmi.new()
-   self._zeroMask:norm(rmi, 2, vectorDim)
-   self.zeroMask = self.zeroMask or (
-      (torch.type(rmi) == 'torch.CudaTensor') and torch.CudaByteTensor()
-      or (torch.type(rmi) == 'torch.ClTensor') and torch.ClTensor()
-      or torch.ByteTensor()
-   )
-   self._zeroMask.eq(self.zeroMask, self._zeroMask, 0)
+   if self.maskinput and self.zeroMask then
+      nn.utils.recursiveZeroMask(input, self.zeroMask)
+   end
 
    -- forward through decorated module
    local output = self.modules[1]:updateOutput(input)
 
-   self.output = self:recursiveMask(self.output, output, self.zeroMask)
+   if self.maskoutput and self.zeroMask then
+      self.output = nn.utils.recursiveCopy(self.output, output)
+      nn.utils.recursiveZeroMask(self.output, self.zeroMask)
+   else
+      self.output = output
+   end
 
    return self.output
 end
 
 function MaskZero:updateGradInput(input, gradOutput)
-   -- zero gradOutputs before backpropagating through decorated module
-   self.gradOutput = self:recursiveMask(self.gradOutput, gradOutput, self.zeroMask)
+   assert(self.zeroMask ~= nil, "MaskZero expecting zeroMask tensor or false")
+
+   if self.maskoutput and self.zeroMask then
+      self.gradOutput = nn.utils.recursiveCopy(self.gradOutput, gradOutput)
+      nn.utils.recursiveZeroMask(self.gradOutput, self.zeroMask)
+      gradOutput = self.gradOutput
+   end
+
+   self.gradInput = self.modules[1]:updateGradInput(input, gradOutput)
+
+   if self.maskinput and self.zeroMask then
+      nn.utils.recursiveZeroMask(self.gradInput, self.zeroMask)
+   end
 
-   self.gradInput = self.modules[1]:updateGradInput(input, self.gradOutput)
    return self.gradInput
 end
 
-function MaskZero:type(type, ...)
+function MaskZero:clearState()
+   self.output = nil
+   self.gradInput = nil
    self.zeroMask = nil
    self._zeroMask = nil
    self._maskbyte = nil
    self._maskindices = nil
+   return self
+end
+
+function MaskZero:type(type, ...)
+   self:clearState()
    return parent.type(self, type, ...)
 end
+
+function MaskZero:setZeroMask(zeroMask)
+   if zeroMask == false then
+      self.zeroMask = false
+   else
+      assert(torch.isByteTensor(zeroMask))
+      assert(zeroMask:isContiguous())
+      self.zeroMask = zeroMask
+   end
+end
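
For readers of this commit, here is a minimal usage sketch of the new v2 interface introduced above. The example is not part of the commit: it assumes the rnn package is installed and that nn.utils.getZeroMaskBatch(input) returns a contiguous ByteTensor with a 1 for every all-zero row of the batch, as the backwards-compat path in updateOutput suggests.

   require 'rnn'

   -- Sketch: decorate nn.Linear with MaskZero and zero the output rows
   -- that correspond to padded (all-zero) input rows.
   local module = nn.MaskZero(nn.Linear(4, 3)) -- v2: zeroMask must be set explicitly

   local input = torch.randn(5, 4)
   input[2]:zero() -- pretend the second sample in the batch is padding

   -- Build the zeroMask (1 where a batch row is entirely zero) and set it
   -- before each forward call (assumption: one-arg getZeroMaskBatch allocates
   -- the ByteTensor internally).
   local zeroMask = nn.utils.getZeroMaskBatch(input)
   module:setZeroMask(zeroMask)

   local output = module:forward(input)
   print(output[2]:sum()) -- 0: the padded row is zeroed in the output

The v1 path remains reachable for backwards compatibility: constructing nn.MaskZero(module, true) sets self.v2 to false, so updateOutput infers the zeroMask from the input itself via nn.utils.getZeroMaskBatch instead of requiring an explicit setZeroMask call.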