Skip to content

Commit

Permalink
fix division by zero error
Browse files — browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Oct 13, 2015
1 parent 0d3283b commit f023676
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 2 deletions.
8 changes: 7 additions & 1 deletion ReinforceCategorical.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ function ReinforceCategorical:updateOutput(input)
self._index = self._index or ((torch.type(input) == 'torch.CudaTensor') and torch.CudaTensor() or torch.LongTensor())
if self.stochastic or self.train ~= false then
-- sample from categorical with p = input
self._input = self._input or input.new()
-- prevent division by zero error (see updateGradInput)
self._input:resizeAs(input):copy(input):add(0.00000001)
input.multinomial(self._index, input, 1)
-- one hot encoding
self.output:zero()
Expand All @@ -36,7 +39,10 @@ function ReinforceCategorical:updateGradInput(input, gradOutput)
-- d p 0 otherwise
self.gradInput:resizeAs(input):zero()
self.gradInput:copy(self.output)
self.gradInput:cdiv(input)
self._input = self._input or input.new()
-- prevent division by zero error
self._input:resizeAs(input):copy(input):add(0.00000001)
self.gradInput:cdiv(self._input)

-- multiply by reward
self.gradInput:cmul(self:rewardAs(input))
Expand Down
3 changes: 2 additions & 1 deletion test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -644,7 +644,7 @@ function dpnntest.ReinforceCategorical()
rc:reinforce(reward)
local gradInput = rc:updateGradInput(input, gradOutput)
local gradInput2 = output:clone()
gradInput2:cdiv(input)
gradInput2:cdiv(input+0.00000001)
local reward2 = reward:view(1000,1):expandAs(input)
gradInput2:cmul(reward2):mul(-1)
mytester:assertTensorEq(gradInput2, gradInput, 0.00001, "ReinforceCategorical backward err")
Expand Down Expand Up @@ -878,6 +878,7 @@ function dpnntest.TotalDropout()
end

function dpnnbigtest.Reinforce()
error"this needs to be updated with new VRClassReward interface"
-- let us try to reinforce an mlp to learn a simple distribution
local n = 10
local inputs = torch.Tensor(n,3):uniform(0,0.1)
Expand Down

0 comments on commit f023676

Please sign in to comment.