stepClone + ReinforceNormal stdev fix
nicholas-leonard committed Oct 27, 2015
1 parent 776a555 commit e011a4f
Showing 3 changed files with 13 additions and 10 deletions.
8 changes: 4 additions & 4 deletions Module.lua
@@ -33,7 +33,7 @@ Module.dpnn_gradParameters = {'gradWeight', 'gradBias'}
 -- self.modules, self.sharedClones or as an attribute to self.
 -- So if you store a module in self.mytbl = {mymodule}, it will be cloned
 -- independently of sharedClone (i.e. deep copy).
-function Module:sharedClone(shareParams, shareGradParams, clones, pointers)
+function Module:sharedClone(shareParams, shareGradParams, clones, pointers, stepClone)
    shareParams = (shareParams == nil) and true or shareParams
    shareGradParams = (shareGradParams == nil) and true or shareGradParams
 
@@ -48,7 +48,7 @@ function Module:sharedClone(shareParams, shareGradParams, clones, pointers)
    for i,module in ipairs(self.modules) do
       local clone
       if not clones[torch.pointer(module)] then
-         clone = module:sharedClone(shareParams, shareGradParams, clones, pointers)
+         clone = module:sharedClone(shareParams, shareGradParams, clones, pointers, stepClone)
          clones[torch.pointer(module)] = clone
       else
          clone = clones[torch.pointer(module)]
@@ -66,7 +66,7 @@ function Module:sharedClone(shareParams, shareGradParams, clones, pointers)
    for i,sharedClone in pairs(self.sharedClones) do
       local clone
       if not clones[torch.pointer(sharedClone)] then
-         clone = sharedClone:sharedClone(shareParams, shareGradParams, clones, pointers)
+         clone = sharedClone:sharedClone(shareParams, shareGradParams, clones, pointers, stepClone)
          clones[torch.pointer(sharedClone)] = clone
       else
          clone = clones[torch.pointer(sharedClone)]
@@ -84,7 +84,7 @@ function Module:sharedClone(shareParams, shareGradParams, clones, pointers)
 
    local clone
    if not clones[torch.pointer(module)] then
-      clone = module:sharedClone(shareParams, shareGradParams, clones, pointers)
+      clone = module:sharedClone(shareParams, shareGradParams, clones, pointers, stepClone)
       clones[torch.pointer(module)] = clone
    else
       clone = clones[torch.pointer(module)]
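The change to Module.lua is mechanical: the new stepClone argument is threaded through every recursive sharedClone call so that nested modules and existing shared clones receive the same flag. A minimal usage sketch, assuming dpnn is installed; what the flag does downstream is outside these hunks:

require 'dpnn'

-- build a small container so the recursion over self.modules is exercised
local mlp = nn.Sequential()
mlp:add(nn.Linear(10, 10))
mlp:add(nn.Tanh())

-- share parameters and gradParameters with the original; the trailing
-- argument is the stepClone flag introduced by this commit, which is now
-- forwarded to nn.Linear and nn.Tanh inside the container
local clone = mlp:sharedClone(true, true, nil, nil, true)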
9 changes: 5 additions & 4 deletions ReinforceNormal.lua
@@ -93,12 +93,11 @@ function ReinforceNormal:updateGradInput(input, gradOutput)
          gradMean:cdiv(self.__stdev):cdiv(self.__stdev)
       end
    end
-   -- multiply by reward
-   gradMean:cmul(self:rewardAs(mean))
+   -- multiply by reward
+   gradMean:cmul(self:rewardAs(mean) )
    -- multiply by -1 ( gradient descent on mean )
    gradMean:mul(-1)
 
-
    -- Derivative of log normal w.r.t. stdev :
    -- d ln(f(x,u,s))   (x - u)^2 - s^2
    -- -------------- = ---------------
@@ -113,8 +112,10 @@ function ReinforceNormal:updateGradInput(input, gradOutput)
    self._stdev2:resizeAs(stdev):copy(stdev):cmul(stdev)
    gradStdev:add(-1, self._stdev2)
    -- divide by s^3
-   self._stdev2:cmul(stdev)
+   self._stdev2:cmul(stdev):add(0.00000001)
    gradStdev:cdiv(self._stdev2)
+   -- multiply by reward
+   gradStdev:cmul(self:rewardAs(stdev))
    -- multiply by -1 ( gradient descent on stdev )
    gradStdev:mul(-1)
    end
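For reference, the derivative the truncated source comment points at follows from standard algebra on the normal log-density (textbook math, not part of this diff):

\[
\ln f(x; u, s) = -\ln\!\big(s\sqrt{2\pi}\big) - \frac{(x-u)^2}{2s^2},
\qquad
\frac{\partial \ln f}{\partial s} = -\frac{1}{s} + \frac{(x-u)^2}{s^3}
= \frac{(x-u)^2 - s^2}{s^3}.
\]

The added :add(0.00000001) puts a small epsilon in the s^3 denominator before cdiv, keeping the gradient finite when a standard deviation collapses toward zero; the other two added lines fix the actual bug, multiplying the stdev gradient by the per-sample reward just as the mean gradient already was.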
6 changes: 4 additions & 2 deletions test/test.lua
@@ -602,8 +602,10 @@ function dpnntest.ReinforceNormal()
    local gradStdev = output:clone():add(-1, mean):pow(2)
    local stdev2 = torch.cmul(stdev,stdev)
    gradStdev:add(-1,stdev2)
-   stdev2:cmul(stdev)
-   gradStdev:cdiv(stdev2):mul(-1)
+   stdev2:cmul(stdev):add(0.00000001)
+   gradStdev:cdiv(stdev2)
+   local reward2 = reward:view(4,1):expandAs(gradStdev)
+   gradStdev:cmul(reward2):mul(-1)
    mytester:assertTensorEq(gradInput[2], gradStdev, 0.000001, "ReinforceNormal backward table input - gradStdev err")
 end

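A hypothetical end-to-end sketch of the path the updated test exercises. Assumptions: nn.ReinforceNormal constructed without a fixed stdev takes a {mean, stdev} table as input, and the reward reaches the module via reinforce, as in dpnn's other REINFORCE modules; the exact constructor arguments are not shown in this diff:

require 'dpnn'

local rn = nn.ReinforceNormal()           -- no fixed stdev: input is {mean, stdev}
local mean = torch.randn(4, 10)
local stdev = torch.rand(4, 10):add(0.1)  -- keep the stdev strictly positive
local output = rn:forward{mean, stdev}

rn:reinforce(torch.randn(4))              -- one REINFORCE reward per sample
local gradInput = rn:backward({mean, stdev}, output:clone():zero())

-- gradInput[2] is the gradient w.r.t. stdev; after this commit it is scaled
-- by the expanded reward and stabilized by the 1e-8 epsilon before division
print(gradInput[2]:size())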
