Skip to content

Commit

Permalink
AbstractRecurrent:apply() fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Leonard committed Jun 7, 2017
1 parent fee152c commit cb227b4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
6 changes: 6 additions & 0 deletions AbstractRecurrent.lua
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,12 @@ function AbstractRecurrent:evaluate()
end)
end

function AbstractRecurrent:apply(callback)
for k,stepmodule in pairs(self.sharedClones) do
stepmodule:apply(callback)
end
end

function AbstractRecurrent:reinforce(reward)
local a = torch.Timer()
if torch.type(reward) == 'table' then
Expand Down
1 change: 1 addition & 0 deletions scripts/evaluate-rnnlm.lua
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
require 'nngraph'
require 'rnn'
require 'optim'
local dl = require 'dataload'


Expand Down
4 changes: 2 additions & 2 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4348,8 +4348,8 @@ function rnntest.RecLSTM_maskzero()
seqlstm2.bias:copy(reclstm.modules[1].bias)

local input = torch.randn(T, N, D)
input[{2,1}]:fill(0)
input[{3,2}]:fill(0)
--input[{2,1}]:fill(0)
--input[{3,2}]:fill(0)
local gradOutput = torch.randn(T, N, H)

local zeroMask = torch.ByteTensor(T, N):zero()
Expand Down

0 comments on commit cb227b4

Please sign in to comment.