Skip to content

Commit

Permalink
Merge branch 'ywelement-zerosoutput'
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Nov 19, 2015
2 parents fabfefd + 19222fb commit 98051e2
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 0 deletions.
7 changes: 7 additions & 0 deletions MaskZero.lua
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ function MaskZero:recursiveMask(output, input, mask)
assert(torch.isTensor(input))
output = torch.isTensor(output) and output or input.new()

-- make sure mask has the same dimenion as the input tensor
local inputSize = input:size():fill(1)
if input:dim() - 1 == self.nInputDim then
inputSize[1] = input:size(1)
end
mask:resize(inputSize)
-- build mask
local zeroMask = mask:expandAs(input)
output:resizeAs(input):copy(input)
output:maskedFill(zeroMask, 0)
Expand Down
28 changes: 28 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2318,6 +2318,34 @@ function rnntest.MaskZero()
local input = torch.rand(name == 'lstm' and 3 or 2, 5, 10)
local err = jac.testJacobian(module,input)
mytester:assertlt(err, precision, 'batch error on state for ' .. name)

-- full test on convolution and linear modules
local module = nn.Sequential() :add( nn.ParallelTable() :add(nn.SpatialConvolution(1,2,3,3)) :add(nn.Linear(100,2)) )
--module = module:float()
local batchNum = 5
local input = {torch.rand(batchNum,1,10,10), torch.rand(batchNum,100)}
local zeroRowNum = 2
for i = 1,#input do
input[i]:narrow(1,1,zeroRowNum):zero()
end
--module = nn.MaskZero(module, 3)
local output = module:forward(input)
for i = 1,#input do
for j = 1,batchNum do
local rmi = input[i][j]:view(-1) -- collapse dims
local vectorDim = rmi:dim()
local rn = rmi.new()
rn:norm(rmi, 2, vectorDim)
local err = rn[1]
if j<=zeroRowNum then
-- check zero outputs
mytester:assertlt(err, precision, 'batch ' ..i.. ':' ..j.. ' error on state for ' .. name)
else
-- check non-zero outputs
mytester:assertgt(err, precision, 'batch ' ..i.. ':' ..j.. ' error on state for ' .. name)
end
end
end
end
end

Expand Down

0 comments on commit 98051e2

Please sign in to comment.