Skip to content

Commit

Permalink
ReverseTable+SeqReverseSequence -> ReverseSequence
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Leonard committed May 17, 2017
1 parent 6ca4c57 commit 72091ff
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 49 deletions.
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ SET(luasrc
SeqGRU.lua
SeqLSTM.lua
deprecated/SeqLSTMP.lua
SeqReverseSequence.lua
deprecated/SeqReverseSequence.lua
Sequencer.lua
SequencerCriterion.lua
ZeroGrad.lua
Expand Down Expand Up @@ -85,7 +85,7 @@ SET(luasrc
ReinforceCategorical.lua
ReinforceGamma.lua
ReinforceNormal.lua
ReverseTable.lua
ReverseSequence.lua
Sequential.lua
Serial.lua
SimpleColorTransform.lua
Expand Down
75 changes: 75 additions & 0 deletions ReverseSequence.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
-- nn.ReverseSequence reverses the order of a sequence.
-- Input may be a tensor (reversed along the first/time dimension) or a
-- table of steps (reversed element order). Replaces the old ReverseTable.
local ReverseSequence, parent = torch.class("nn.ReverseSequence", "nn.Module")

--- Forward pass: output is the input with its sequence order reversed.
-- Tensor input : output[i] = input[seqlen-i+1] along dimension 1.
-- Table input  : output[k] = input[#input-k+1].
function ReverseSequence:updateOutput(input)
   if torch.isTensor(input) then
      local seqlen = input:size(1)
      -- Reuse the cached output buffer only when its type matches the input;
      -- a stale buffer of another type (e.g. CPU after a CUDA switch) would
      -- make index() fail below.
      if not (torch.isTensor(self.output) and torch.type(self.output) == torch.type(input)) then
         self.output = input.new()
      end
      self.output:resizeAs(input)

      -- Cache the reversed index vector [seqlen, seqlen-1, ..., 1].
      -- Element count uniquely determines its content, so it only needs
      -- rebuilding when the sequence length changes.
      self._range = self._range or (torch.isCudaTensor(input) and torch.CudaLongTensor() or torch.LongTensor())
      if self._range:nElement() ~= seqlen then
         self._range:range(seqlen, 1, -1)
      end
      self.output:index(input, 1, self._range)
   else
      assert(torch.type(input) == 'table', "Expecting table or tensor at arg 1")
      self.output = torch.type(self.output) == 'table' and self.output or {}

      -- Empty the output table; it may hold entries from a previous,
      -- longer sequence.
      for k,v in ipairs(self.output) do
         self.output[k] = nil
      end

      -- Copy input steps in reverse order.
      local seqlen = #input
      local k = 1
      for i=seqlen,1,-1 do
         self.output[k] = input[i]
         k = k + 1
      end
   end

   return self.output
end

--- Backward pass: gradInput is gradOutput with its sequence order reversed
-- (reversal is its own inverse, so the same index mapping applies).
function ReverseSequence:updateGradInput(input, gradOutput)
   if torch.isTensor(input) then
      -- Reuse the cached gradInput buffer only when its type matches the
      -- input; a stale buffer of another type would make index() fail.
      if not (torch.isTensor(self.gradInput) and torch.type(self.gradInput) == torch.type(input)) then
         self.gradInput = input.new()
      end
      self.gradInput:resizeAs(input)

      -- self._range was built by updateOutput for this sequence length.
      self.gradInput:index(gradOutput, 1, self._range)
   else
      assert(torch.type(gradOutput) == 'table', "Expecting table or tensor at arg 2")
      self.gradInput = torch.type(self.gradInput) == 'table' and self.gradInput or {}

      -- Empty the gradInput table; it may hold entries from a previous,
      -- longer sequence.
      for k,v in ipairs(self.gradInput) do
         self.gradInput[k] = nil
      end

      -- Copy gradOutput steps in reverse order.
      local seqlen = #input
      local k = 1
      for i=seqlen,1,-1 do
         self.gradInput[k] = gradOutput[i]
         k = k + 1
      end
   end

   return self.gradInput
end

--- Release cached buffers. Tensor buffers are emptied in place with set()
-- so they keep their current type (a :cuda()'d module stays CUDA); table
-- outputs are reset to a fresh default tensor, as before — updateOutput /
-- updateGradInput rebuild either form lazily.
function ReverseSequence:clearState()
   if torch.isTensor(self.output) then
      self.output:set()
   else
      self.output = torch.Tensor()
   end
   if torch.isTensor(self.gradInput) then
      self.gradInput:set()
   else
      self.gradInput = torch.Tensor()
   end
   self._range = nil
end

-- Convert the module (and its buffers) to the given tensor type.
-- Cached state is cleared first so that buffers of the old type
-- (notably self._range) are rebuilt lazily in the new type on the
-- next forward pass. All arguments are forwarded to nn.Module.type.
function ReverseSequence:type(...)
self:clearState()
return parent.type(self, ...)
end

39 changes: 0 additions & 39 deletions ReverseTable.lua

This file was deleted.

File renamed without changes.
4 changes: 2 additions & 2 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ require('rnn.Collapse')
require('rnn.ZipTable')
require('rnn.ZipTableOneToMany')
require('rnn.CAddTensorTable')
require('rnn.ReverseTable')
require('rnn.ReverseSequence')
require('rnn.Dictionary')
require('rnn.Inception')
require('rnn.Clip')
Expand Down Expand Up @@ -131,7 +131,6 @@ require('rnn.RecurrentAttention')
-- sequencer + recurrent modules
require('rnn.SeqLSTM')
require('rnn.SeqGRU')
require('rnn.SeqReverseSequence')
require('rnn.SeqBRNN')

-- recurrent criterions:
Expand All @@ -144,6 +143,7 @@ require('rnn.MaskZeroCriterion')
require('rnn.LSTM')
require('rnn.FastLSTM')
require('rnn.SeqLSTMP')
require('rnn.SeqReverseSequence')

-- prevent likely name conflicts
nn.rnn = rnn
Expand Down
33 changes: 27 additions & 6 deletions test/test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -4991,24 +4991,45 @@ function rnntest.CAddTensorTable()
mytester:assertTensorEq(output[1]+output[2]+output[3], gradInput[1], 0.000001, "CAddTensorTable gradInput1")
end

function rnntest.ReverseTable()
function rnntest.ReverseSequence()
-- test table

-- input : { a, b, c, d }
-- output : { c, b, a, d }
local r = nn.ReverseTable()
local r = nn.ReverseSequence()
local input = {torch.randn(3,4), torch.randn(3,4), torch.randn(3,4), torch.randn(3,4)}
local output = r:forward(input)

mytester:assert(#output == 4, "ReverseTable #output")
mytester:assert(#output == 4, "ReverseSequence #output")
local k = 1
for i=#input,1,-1 do
mytester:assertTensorEq(input[i], output[k], 0.00001, "ReverseTable output err "..k)
mytester:assertTensorEq(input[i], output[k], 0.00001, "ReverseSequence output err "..k)
k = k + 1
end

local gradInput = r:backward(input, output)
mytester:assert(#gradInput == 4, "ReverseTable #gradInput")
mytester:assert(#gradInput == 4, "ReverseSequence #gradInput")
for i=1,#input do
mytester:assertTensorEq(gradInput[i], input[i], 0.00001, "ReverseTable gradInput err "..i)
mytester:assertTensorEq(gradInput[i], input[i], 0.00001, "ReverseSequence gradInput err "..i)
end

-- test tensor

local r = nn.ReverseSequence()
local input = torch.randn(5,4,3)
local output = r:forward(input)

mytester:assert(output:isSameSizeAs(input), "ReverseSequence #output")
local k = 1
for i=5,1,-1 do
mytester:assertTensorEq(input[i], output[k], 0.00001, "ReverseSequence output err "..k)
k = k + 1
end

local gradInput = r:backward(input, output)
mytester:assert(gradInput:isSameSizeAs(input), "ReverseSequence #gradInput")
for i=1,5 do
mytester:assertTensorEq(gradInput[i], input[i], 0.00001, "ReverseSequence gradInput err "..i)
end
end

Expand Down

0 comments on commit 72091ff

Please sign in to comment.