Skip to content

Commit

Permalink
AbstractRecurrent:reinforce() handles multiple rewards
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholas-leonard committed Feb 25, 2016
1 parent e8940ed commit 417f8df
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 3 deletions.
16 changes: 13 additions & 3 deletions AbstractRecurrent.lua
Original file line number Diff line number Diff line change
Expand Up @@ -155,9 +155,19 @@ function AbstractRecurrent:evaluate()
end

-- Propagate a REINFORCE reward signal into the recurrent module.
-- `reward` is either a single reward (broadcast to every time-step and
-- all shared clones) or a table with one reward per time-step, in which
-- case entry i is dispatched to the step-i clone via getStepModule.
function AbstractRecurrent:reinforce(reward)
   if torch.type(reward) ~= 'table' then
      -- single reward: broadcast through all shared clones
      return self:includingSharedClones(function()
         return parent.reinforce(self, reward)
      end)
   end
   -- per-time-step rewards: route each entry to its own step module
   for step, stepReward in ipairs(reward) do
      self:getStepModule(step):reinforce(stepReward)
   end
end

-- used by Recursor() after calling stepClone.
Expand Down
20 changes: 20 additions & 0 deletions Recurrent.lua
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,26 @@ function Recurrent:includingSharedClones(f)
return r
end

-- Propagate a REINFORCE reward signal. Recurrent uses a distinct module
-- (initialModule) for the first time-step, so when `reward` is a table
-- of per-time-step rewards the first entry goes to initialModule and the
-- rest to their step clones. A non-table reward is broadcast to all
-- time-steps through the shared clones.
function Recurrent:reinforce(reward)
   if torch.type(reward) == 'table' then
      for step, stepReward in ipairs(reward) do
         -- step 1 is handled by initialModule, later steps by step clones
         local target = (step == 1) and self.initialModule or self:getStepModule(step)
         target:reinforce(stepReward)
      end
      return
   end
   -- single reward broadcast to every time-step
   return self:includingSharedClones(function()
      return parent.reinforce(self, reward)
   end)
end

function Recurrent:maskZero()
error("Recurrent doesn't support maskZero as it uses a different "..
"module for the first time-step. Use nn.Recurrence instead.")
Expand Down
17 changes: 17 additions & 0 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4047,6 +4047,23 @@ function rnntest.encoderdecoder()
numgradtest()
end

function rnntest.reinforce()
   -- verify that AbstractRecurrent:reinforce() works when given a table of
   -- rewards (one tensor of per-sample rewards per time-step): each entry
   -- must end up on the matching step module's `reward` field
   local nSteps, batchSize = 4, 3
   local rewards = {}
   for step = 1, nSteps do
      rewards[step] = torch.randn(batchSize)
   end
   local rnn = nn.Recursor(nn.ReinforceNormal(0.1))
   rnn:reinforce(rewards)
   for step = 1, nSteps do
      local stepModule = rnn:getStepModule(step)
      mytester:assertTensorEq(stepModule.reward, rewards[step], 0.000001, "Reinforce error")
   end
end

function rnn.test(tests, benchmark_)
mytester = torch.Tester()
benchmark = benchmark_
Expand Down

0 comments on commit 417f8df

Please sign in to comment.