From f023676665ced10e99124f27bd381237aab64288 Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Tue, 13 Oct 2015 17:32:25 -0400 Subject: [PATCH] fix division by zero error --- ReinforceCategorical.lua | 8 +++++++- test/test.lua | 3 ++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/ReinforceCategorical.lua b/ReinforceCategorical.lua index fb19f52..7f66e21 100644 --- a/ReinforceCategorical.lua +++ b/ReinforceCategorical.lua @@ -14,6 +14,9 @@ function ReinforceCategorical:updateOutput(input) self._index = self._index or ((torch.type(input) == 'torch.CudaTensor') and torch.CudaTensor() or torch.LongTensor()) if self.stochastic or self.train ~= false then -- sample from categorical with p = input + self._input = self._input or input.new() + -- prevent division by zero error (see updateGradInput) + self._input:resizeAs(input):copy(input):add(0.00000001) input.multinomial(self._index, input, 1) -- one hot encoding self.output:zero() @@ -36,7 +39,10 @@ function ReinforceCategorical:updateGradInput(input, gradOutput) -- d p 0 otherwise self.gradInput:resizeAs(input):zero() self.gradInput:copy(self.output) - self.gradInput:cdiv(input) + self._input = self._input or input.new() + -- prevent division by zero error + self._input:resizeAs(input):copy(input):add(0.00000001) + self.gradInput:cdiv(self._input) -- multiply by reward self.gradInput:cmul(self:rewardAs(input)) diff --git a/test/test.lua b/test/test.lua index 7e0f165..26875bc 100644 --- a/test/test.lua +++ b/test/test.lua @@ -644,7 +644,7 @@ function dpnntest.ReinforceCategorical() rc:reinforce(reward) local gradInput = rc:updateGradInput(input, gradOutput) local gradInput2 = output:clone() - gradInput2:cdiv(input) + gradInput2:cdiv(input+0.00000001) local reward2 = reward:view(1000,1):expandAs(input) gradInput2:cmul(reward2):mul(-1) mytester:assertTensorEq(gradInput2, gradInput, 0.00001, "ReinforceCategorical backward err") @@ -878,6 +878,7 @@ function dpnntest.TotalDropout() end function dpnnbigtest.Reinforce() + error"this needs to be updated with new VRClassReward interface" -- let us try to reinforce an mlp to learn a simple distribution local n = 10 local inputs = torch.Tensor(n,3):uniform(0,0.1)