Commit

TotalDropout

nicholas-leonard committed Aug 18, 2015
1 parent 489b0f2 commit b429afe
Showing 3 changed files with 61 additions and 0 deletions.
38 changes: 38 additions & 0 deletions TotalDropout.lua
@@ -0,0 +1,38 @@
------------------------------------------------------------------------
--[[ TotalDropout ]]--
-- Like vanilla Dropout, but applied to the entire input tensor:
-- with probability p the whole input is zeroed; otherwise it is
-- forwarded unchanged. Unlike nn.Dropout, the output is not
-- rescaled by 1/(1-p) during training.
------------------------------------------------------------------------
local TotalDropout, parent = torch.class("nn.TotalDropout", "nn.Module")

function TotalDropout:__init(p)
   self.p = p or 0.5 -- probability of dropping the entire input
   self.train = true
   if self.p >= 1 or self.p < 0 then
      error('<TotalDropout> illegal probability, must be 0 <= p < 1')
   end
   parent.__init(self)
end

function TotalDropout:updateOutput(input)
   self.output:resizeAs(input):copy(input)
   if self.train then
      -- one Bernoulli draw for the whole tensor:
      -- 1 (keep) with probability 1-p, 0 (drop) with probability p
      self.noise = torch.bernoulli(1 - self.p)
      self.output:mul(self.noise)
   end
   return self.output
end

function TotalDropout:updateGradInput(input, gradOutput)
   if self.train then
      self.gradInput:resizeAs(gradOutput):copy(gradOutput)
      self.gradInput:mul(self.noise) -- mask the gradients with the same scalar noise drawn in updateOutput
   else
      error('backprop only defined while training')
   end
   return self.gradInput
end

function TotalDropout:__tostring__()
   return string.format('%s(%f)', torch.type(self), self.p)
end
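
A minimal usage sketch (an editorial addition, not part of the commit): it assumes the dpnn package is installed and uses only the module above together with nn.Module's stock training()/evaluate() toggles.

require 'dpnn'

torch.manualSeed(1) -- make the Bernoulli draw reproducible

local td = nn.TotalDropout(0.5)
local input = torch.randn(4, 3)

td:training()
local out = td:forward(input)
print(td.noise, out:sum()) -- either (1, input:sum()) or (0, 0)

td:evaluate()
out = td:forward(input)
print((out - input):abs():max()) -- 0: the module is the identity at test time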
1 change: 1 addition & 0 deletions init.lua
@@ -71,6 +71,7 @@ torch.include('dpnn', 'SpatialUniformCrop.lua')
torch.include('dpnn', 'SpatialGlimpse.lua')
torch.include('dpnn', 'ArgMax.lua')
torch.include('dpnn', 'CategoricalEntropy.lua')
torch.include('dpnn', 'TotalDropout.lua')

-- REINFORCE
torch.include('dpnn', 'Reinforce.lua')
22 changes: 22 additions & 0 deletions test/test.lua
@@ -727,6 +727,28 @@ function dpnntest.CategoricalEntropy()
   mytester:assertTensorEq(gradInput2, gradInput, 0.000001, "CategoricalEntropy gradInput err")
end

function dpnntest.TotalDropout()
   local batchSize = 4
   local inputSize = 3
   local input = torch.randn(batchSize, inputSize)
   local gradOutput = torch.randn(batchSize, inputSize)
   local td = nn.TotalDropout()
   local nOne = 0
   for i=1,10 do
      local output = td:forward(input)
      local gradInput = td:backward(input, gradOutput)
      if td.noise == 0 then
         mytester:assert(output:sum() == 0, "TotalDropout forward 0 err")
         mytester:assert(gradInput:sum() == 0, "TotalDropout backward 0 err")
      else
         mytester:assertTensorEq(output, input, 0.000001, "TotalDropout forward 1 err")
         mytester:assertTensorEq(gradInput, gradOutput, 0.000001, "TotalDropout backward 1 err")
         nOne = nOne + 1
      end
   end
   -- probabilistic check that both branches were exercised; see the note after this diff
   mytester:assert(nOne < 10 and nOne > 1, "TotalDropout bernoulli error")
end

function dpnnbigtest.Reinforce()
   -- let us try to reinforce an mlp to learn a simple distribution
   local n = 10
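
An editorial note on the final assertion in dpnntest.TotalDropout: nOne is a Binomial(10, 0.5) draw, so the test fails whenever zero, one, or all ten of the forwards keep the input. A quick self-contained estimate of how often that happens (assuming the defaults above, p = 0.5 over 10 trials):

local n, p = 10, 0.5
local function choose(n, k) -- binomial coefficient, exact for small n
   local c = 1
   for i = 1, k do c = c * (n - i + 1) / i end
   return c
end
-- the assert fails when nOne is 0, 1, or 10
local pFail = (choose(n, 0) + choose(n, 1) + choose(n, n)) * p^n
print(pFail) -- 0.01171875, i.e. roughly 1 unseeded run in 85

Calling torch.manualSeed with a fixed seed before the loop would make the suite deterministic.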
