From cb227b497644af60a2d8148697fd9fb2c15ae5ae Mon Sep 17 00:00:00 2001 From: Nicholas Leonard Date: Wed, 7 Jun 2017 16:21:26 -0400 Subject: [PATCH] AbstractRecurrent:apply() fix --- AbstractRecurrent.lua | 6 ++++++ scripts/evaluate-rnnlm.lua | 1 + test/test.lua | 4 ++-- 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/AbstractRecurrent.lua b/AbstractRecurrent.lua index 3bbac87..ff36ab8 100644 --- a/AbstractRecurrent.lua +++ b/AbstractRecurrent.lua @@ -229,6 +229,12 @@ function AbstractRecurrent:evaluate() end) end +function AbstractRecurrent:apply(callback) + for k,stepmodule in pairs(self.sharedClones) do + stepmodule:apply(callback) + end +end + function AbstractRecurrent:reinforce(reward) local a = torch.Timer() if torch.type(reward) == 'table' then diff --git a/scripts/evaluate-rnnlm.lua b/scripts/evaluate-rnnlm.lua index 865ea2e..5b7e1c6 100644 --- a/scripts/evaluate-rnnlm.lua +++ b/scripts/evaluate-rnnlm.lua @@ -1,5 +1,6 @@ require 'nngraph' require 'rnn' +require 'optim' local dl = require 'dataload' diff --git a/test/test.lua b/test/test.lua index 660883e..591c24a 100644 --- a/test/test.lua +++ b/test/test.lua @@ -4348,8 +4348,8 @@ function rnntest.RecLSTM_maskzero() seqlstm2.bias:copy(reclstm.modules[1].bias) local input = torch.randn(T, N, D) - input[{2,1}]:fill(0) - input[{3,2}]:fill(0) + --input[{2,1}]:fill(0) + --input[{3,2}]:fill(0) local gradOutput = torch.randn(T, N, H) local zeroMask = torch.ByteTensor(T, N):zero()