From b429afe8cec69fe29d687f4a0152a9584e2564fc Mon Sep 17 00:00:00 2001
From: nicholas-leonard
Date: Tue, 18 Aug 2015 14:53:01 -0400
Subject: [PATCH] TotalDropout

---
 TotalDropout.lua | 38 ++++++++++++++++++++++++++++++++++++++
 init.lua         |  1 +
 test/test.lua    | 22 ++++++++++++++++++++++
 3 files changed, 61 insertions(+)
 create mode 100644 TotalDropout.lua

diff --git a/TotalDropout.lua b/TotalDropout.lua
new file mode 100644
index 0000000..b239fec
--- /dev/null
+++ b/TotalDropout.lua
@@ -0,0 +1,38 @@
+------------------------------------------------------------------------
+--[[ TotalDropout ]]--
+-- Like vanilla Dropout, but applied to the entire input:
+-- the input is either forwarded in full or zeroed in full.
+------------------------------------------------------------------------
+local TotalDropout, parent = torch.class("nn.TotalDropout", "nn.Module")
+
+function TotalDropout:__init(p)
+   self.p = p or 0.5
+   self.train = true
+   if self.p >= 1 or self.p < 0 then
+      error('illegal probability, must be 0 <= p < 1')
+   end
+   parent.__init(self)
+end
+
+function TotalDropout:updateOutput(input)
+   self.output:resizeAs(input):copy(input)
+   if self.train then
+      self.noise = torch.bernoulli(1-self.p)
+      self.output:mul(self.noise)
+   end
+   return self.output
+end
+
+function TotalDropout:updateGradInput(input, gradOutput)
+   if self.train then
+      self.gradInput:resizeAs(gradOutput):copy(gradOutput)
+      self.gradInput:mul(self.noise) -- mask the gradients with the same 0/1 scalar noise
+   else
+      error('backprop only defined while training')
+   end
+   return self.gradInput
+end
+
+function TotalDropout:__tostring__()
+   return string.format('%s(%f)', torch.type(self), self.p)
+end
diff --git a/init.lua b/init.lua
index c596aee..5f4e449 100644
--- a/init.lua
+++ b/init.lua
@@ -71,6 +71,7 @@ torch.include('dpnn', 'SpatialUniformCrop.lua')
 torch.include('dpnn', 'SpatialGlimpse.lua')
 torch.include('dpnn', 'ArgMax.lua')
 torch.include('dpnn', 'CategoricalEntropy.lua')
+torch.include('dpnn', 'TotalDropout.lua')
 
 -- REINFORCE
 torch.include('dpnn', 'Reinforce.lua')
diff --git a/test/test.lua b/test/test.lua
index 0b1fcf8..a12e042 100644
--- a/test/test.lua
+++ b/test/test.lua
@@ -727,6 +727,28 @@ function dpnntest.CategoricalEntropy()
    mytester:assertTensorEq(gradInput2, gradInput, 0.000001, "CategoricalEntropy gradInput err")
 end
 
+function dpnntest.TotalDropout()
+   local batchSize = 4
+   local inputSize = 3
+   local input = torch.randn(batchSize, inputSize)
+   local gradOutput = torch.randn(batchSize, inputSize)
+   local td = nn.TotalDropout()
+   local nOne = 0
+   for i=1,10 do
+      local output = td:forward(input)
+      local gradInput = td:backward(input, gradOutput)
+      if td.noise == 0 then
+         mytester:assert(output:sum() == 0, "TotalDropout forward 0 err")
+         mytester:assert(gradInput:sum() == 0, "TotalDropout backward 0 err")
+      else
+         mytester:assertTensorEq(output, input, 0.000001, "TotalDropout forward 1 err")
+         mytester:assertTensorEq(gradInput, gradOutput, 0.000001, "TotalDropout backward 1 err")
+         nOne = nOne + 1
+      end
+   end
+   mytester:assert(nOne < 10 and nOne > 1, "TotalDropout bernoulli error")
+end
+
 function dpnnbigtest.Reinforce()
    -- let us try to reinforce an mlp to learn a simple distribution
    local n = 10
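
A minimal usage sketch of the new module (assuming dpnn is installed with this
patch applied; the tensor sizes below are illustrative). Note that, unlike
nn.Dropout, TotalDropout does not rescale the surviving output by 1/(1-p), and
backward raises an error in evaluation mode:

   require 'dpnn'

   local input = torch.randn(4, 3)       -- any shape works; the mask is a scalar
   local gradOutput = torch.randn(4, 3)
   local td = nn.TotalDropout(0.5)       -- zero the whole input with probability 0.5

   td:training()                         -- training mode (the default)
   local output = td:forward(input)      -- either a copy of input or all zeros
   local gradInput = td:backward(input, gradOutput) -- masked by the same 0/1 scalar

   td:evaluate()                         -- evaluation mode
   output = td:forward(input)            -- input passes through unchanged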