
Commit

fix merge
nicholas-leonard committed Nov 25, 2016
2 parents 2267dc7 + 5b2eb7f commit c97f143
Showing 3 changed files with 49 additions and 13 deletions.
14 changes: 12 additions & 2 deletions Module.lua
@@ -59,6 +59,11 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
          moduleTree = obj
          obj = nil
          isTable = false
+      elseif obj.dpnn_sharedClone then
+         -- allow to use a custom sharedClone method on one module
+         moduleTree = obj
+         obj = nil
+         isTable = false
       elseif scdone[torch.pointer(obj)] then
          moduleTree = scdone[torch.pointer(obj)]
       else
@@ -142,8 +147,13 @@ function Module:sharedClone(shareParams, shareGradParams, stepClone)
    if scdone[torch.pointer(original)] then
       for k,param in pairs(moduleTree) do
          if torch.isTypeOf(param,'nn.Module') then
-            -- AbstractRecurrent instances branch here with stepClone = true
-            clone[k] = param
+            if param.dpnn_sharedClone then
+               -- Call the custom sharedClone
+               clone[k] = param:dpnn_sharedClone()
+            else
+               -- AbstractRecurrent instances branch here with stepClone = true
+               clone[k] = param
+            end
             original[k] = param
          elseif torch.isTensor(param) then
             if param.storage then
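
Note that the hook above is duck-typed: Module:sharedClone delegates to any module that defines a dpnn_sharedClone method, both when first walking the module tree and when filling in the clone. A minimal sketch of a module using the hook; the nn.NoSharedClone class and its deep-copy behaviour are hypothetical, only the dpnn_sharedClone method name and its zero-argument call come from the diff:

require 'dpnn'

-- Hypothetical module that supplies its own sharedClone behaviour.
local NoSharedClone = torch.class('nn.NoSharedClone', 'nn.Linear')

-- Module:sharedClone calls this (with no arguments) whenever the method
-- exists; here a deep copy is returned instead of a clone that shares
-- parameters with the original.
function NoSharedClone:dpnn_sharedClone()
   return self:clone()
end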
21 changes: 16 additions & 5 deletions VRClassReward.lua
@@ -22,19 +22,28 @@ function VRClassReward:updateOutput(input, target)
    assert(torch.type(input) == 'table')
    local input = self:toBatch(input[1], 1)
    self._maxVal = self._maxVal or input.new()
-   self._maxIdx = self._maxIdx or torch.type(input) == 'torch.CudaTensor' and input.new() or torch.LongTensor()
+   self._maxIdx = self._maxIdx or torch.type(input) == 'torch.CudaTensor' and torch.CudaLongTensor() or torch.LongTensor()

    -- max class value is class prediction
    self._maxVal:max(self._maxIdx, input, 2)
-   if torch.type(self._maxIdx) ~= torch.type(target) then
-      self._target = self._target or self._maxIdx.new()
+
+   -- reward = scale when correctly classified
+   local maxIdx = self._maxIdx
+   if torch.type(self._maxIdx) == 'torch.CudaLongTensor' then
+      self.__maxIdx = self.__maxIdx or torch.CudaTensor()
+      self.__maxIdx:resize(maxIdx:size()):copy(maxIdx)
+      maxIdx = self.__maxIdx
+   end
+
+   if torch.type(maxIdx) ~= torch.type(target) then
+      self._target = self._target or maxIdx.new()
       self._target:resize(target:size()):copy(target)
       target = self._target
    end

    -- reward = scale when correctly classified
-   self._reward = self._reward or self._maxIdx.new()
-   self._reward:eq(self._maxIdx, target)
+   self._reward = self._reward or maxIdx.new()
+   self._reward:eq(maxIdx, target)
    self.reward = self.reward or input.new()
    self.reward:resize(self._reward:size(1)):copy(self._reward)
    self.reward:mul(self.scale)
@@ -66,6 +75,7 @@ function VRClassReward:updateGradInput(inputTable, target)
    self.gradInput[1] = self:fromBatch(self.gradInput[1], 1)

    -- learn the baseline reward
+   self.criterion:forward(baseline, self.reward)
    self.gradInput[2] = self.criterion:backward(baseline, self.reward)
    self.gradInput[2] = self:fromBatch(self.gradInput[2], 1)
    return self.gradInput
@@ -74,6 +84,7 @@ end
 function VRClassReward:type(type)
    self._maxVal = nil
    self._maxIdx = nil
+   self.__maxIdx = nil
    self._target = nil
    local module = self.module
    self.module = nil
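
Two fixes land in VRClassReward.lua. First, with newer cutorch the max call fills a torch.CudaLongTensor of argmax indices, which cannot be passed to eq against the float CudaTensor target, so the indices are first copied into a plain CudaTensor buffer (self.__maxIdx); that buffer is also cleared in :type() so type conversions rebuild it. Second, updateGradInput now calls criterion:forward before criterion:backward, matching the nn convention that a criterion's backward assumes forward was already run on the same input/target pair. A standalone sketch of the index-type bridge, assuming cutorch is installed (tensor sizes are illustrative):

require 'cutorch'

local input = torch.CudaTensor(4, 10):uniform()
local maxVal, maxIdx = torch.CudaTensor(), torch.CudaLongTensor()
maxVal:max(maxIdx, input, 2)  -- maxIdx holds argmax indices as a CudaLongTensor

-- copy the long indices into a regular CudaTensor so they can be
-- compared element-wise against another CudaTensor
local bridged = torch.CudaTensor():resize(maxIdx:size()):copy(maxIdx)
local target = torch.CudaTensor(4, 1):fill(1)
local correct = torch.CudaTensor():eq(bridged, target)  -- 1 where index == target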
27 changes: 21 additions & 6 deletions test/test.lua
@@ -806,19 +806,34 @@ function dpnntest.ReinforceCategorical()
 end

 function dpnntest.VRClassReward()
-   local input = {torch.randn(13,10), torch.randn(13,1)}
+   local input = {torch.randn(13,10):float(), torch.randn(13,1):float()}
    local target = torch.IntTensor(13):random(1,10)
-   local rf = nn.Reinforce()
-   local vrc = nn.VRClassReward(rf)
+   local rf = nn.Reinforce():float()
+   local vrc = nn.VRClassReward(rf):float()
    local err = vrc:forward(input, target)
    local gradInput = vrc:backward(input, target)
    local val, idx = input[1]:max(2)
-   local reward = torch.eq(idx:select(2,1):int(), target):double()
+   local reward = torch.eq(idx:select(2,1):int(), target):float()
    local err2 = -reward:mean()
    mytester:assert(err == err2, "VRClassReward forward err")
-   local gradInput2 = nn.MSECriterion():backward(input[2], reward)
+   local gradInput2 = nn.MSECriterion():float():backward(input[2], reward)
    mytester:assertTensorEq(gradInput[2], gradInput2, 0.000001, "VRClassReward backward baseline err")
-   mytester:assertTensorEq(gradInput[1], input[1]:zero(), 0.000001, "VRClassReward backward class err")
+   mytester:assert(math.abs(gradInput[1]:sum()) < 0.000001, "VRClassReward backward class err")
+
+   if pcall(function() require 'cunn' end) then
+      local gradInput = {gradInput[1], gradInput[2]}
+      input[1], input[2] = input[1]:cuda(), input[2]:cuda()
+      target = target:cuda()
+      rf:cuda()
+      vrc:cuda()
+
+      local err2 = vrc:forward(input, target)
+      local gradInput2 = vrc:backward(input, target)
+
+      mytester:assert(math.abs(err - err2) < 0.000001, "VRClassReward forward cuda err")
+      mytester:assertTensorEq(gradInput[2], gradInput2[2]:float(), 0.000001, "VRClassReward backward baseline cuda err")
+      mytester:assertTensorEq(gradInput[1], gradInput2[1]:float(), 0.000001, "VRClassReward backward class cuda err")
+   end
 end

 function dpnntest.BinaryClassReward()
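
The test now runs in float precision and, when cunn loads, repeats the forward/backward pass on CUDA tensors and checks that both paths agree within 1e-6. The class-gradient check also changed from comparing against input[1]:zero() to asserting the gradient's sum is near zero, presumably because :zero() destructively wiped the very input tensor the new CUDA branch still needs. To run just this test, something like the following should work, assuming dpnn's mytester-based entry point accepts a table of test names as these suites conventionally do:

th -e "dpnn = require 'dpnn'; dpnn.test{'VRClassReward'}"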
