Skip to content

Commit

Permalink
GRU and LSTM optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholas Leonard committed May 12, 2017
1 parent 71d80f8 commit e5508b7
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 16 deletions.
12 changes: 9 additions & 3 deletions SeqGRU.lua
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,12 @@ function SeqGRU:updateOutput(input)

local h = self.output
h:resize(seqlen, batchsize, outputsize):zero()
self.gates:resize(seqlen, batchsize, 3 * outputsize):zero()

local nElement = self.gates:nElement()
self.gates:resize(seqlen, batchsize, 3 * outputsize)
if nElement ~= seqlen * batchsize * 3 * outputsize then
self.gates:zero()
end

local prev_h = h0
if input.nn and input.nn.StepGRU_updateOutput and not self.forceLua then
Expand Down Expand Up @@ -184,15 +189,16 @@ function SeqGRU:backward(input, gradOutput, scale)
local u = self.gates[{t, {}, {outputsize + 1, 2 * outputsize}}]
local hc = self.gates[{t, {}, {2 * outputsize + 1, 3 * outputsize}}]

local grad_a = self.grad_a_buffer:resize(batchsize, 3 * outputsize):zero()
local grad_a = self.grad_a_buffer:resize(batchsize, 3 * outputsize)

local grad_ar = grad_a[{{}, {1, outputsize}}]
local grad_au = grad_a[{{}, {outputsize + 1, 2 * outputsize}}]
local grad_ahc = grad_a[{{}, {2 * outputsize + 1, 3 * outputsize}}]

-- use grad_au as temporary buffer to compute grad_ahc.

local grad_hc = grad_au:fill(0):addcmul(grad_next_h, -1, u, grad_next_h)
grad_ahc:fill(1):addcmul(-1, hc,hc):cmul(grad_hc)
grad_ahc:fill(1):addcmul(-1, hc, hc):cmul(grad_hc)
local grad_r = grad_au:fill(0):addmm(grad_ahc, Wh[{{}, {2 * outputsize + 1, 3 * outputsize}}]:t() ):cmul(prev_h)
grad_ar:fill(1):add(-1, r):cmul(r):cmul(grad_r)

Expand Down
5 changes: 5 additions & 0 deletions SeqLSTM.lua
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,12 @@ function SeqLSTM:updateOutput(input)
local h, c = self.output, self.cell
h:resize(seqlen, batchsize, outputsize)
c:resize(seqlen, batchsize, hiddensize)

local nElement = self.gates:nElement()
self.gates:resize(seqlen, batchsize, 4 * hiddensize)
if nElement ~= seqlen * batchsize * 4 * hiddensize then
self.gates:zero()
end

local prev_h, prev_c = h0, c0
if input.nn and input.nn.StepLSTM_updateOutput and not self.forceLua then
Expand Down
10 changes: 7 additions & 3 deletions StepGRU.lua
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,12 @@ function StepGRU:updateOutput(input)
local Wh = self.weight:narrow(1, inputsize + 1, self.outputsize)

next_h:resize(batchsize, outputsize)
self.gates:resize(batchsize, 3 * outputsize):zero()
local gates = self.gates
local nElement = gates:nElement()
gates:resize(batchsize, 3 * outputsize)
if gates:nElement() ~= batchsize * 3 * outputsize then
gates:zero()
end

gates:addmm(bias_expand, cur_x, Wx)
local sub_gates = gates:narrow(2, 1, 2 * outputsize)
Expand Down Expand Up @@ -92,7 +96,6 @@ function StepGRU:backward(input, gradOutput, scale)
scale = scale or 1.0
assert(scale == 1.0, 'must have scale=1')

--
local grad_gates = torch.getBuffer('StepGRU', 'grad_gates', self.gates) -- batchsize x 3*outputsize
local buffer = torch.getBuffer('StepGRU', 'buffer', self.gates) -- 1 x 3*outputsize

Expand Down Expand Up @@ -125,7 +128,8 @@ function StepGRU:backward(input, gradOutput, scale)
local update_gate = gates:narrow(2, outputsize + 1, outputsize)
local hidden_candidate = gates:narrow(2, 2 * outputsize + 1, outputsize)

grad_gates:resize(batchsize, 3 * outputsize):zero()
grad_gates:resize(batchsize, 3 * outputsize)

local grad_reset_gate = grad_gates:narrow(2, 1, outputsize)
local grad_update_gate = grad_gates:narrow(2, outputsize + 1, outputsize)
local grad_hidden_candidate = grad_gates:narrow(2, 2 * outputsize + 1, outputsize)
Expand Down
9 changes: 7 additions & 2 deletions StepLSTM.lua
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,12 @@ function StepLSTM:updateOutput(input)
next_h:resize(batchsize, hiddensize)
next_c:resize(batchsize, hiddensize)

self.gates:resize(batchsize, 4 * hiddensize):zero()
local gates = self.gates
local nElement = gates:nElement()
gates:resize(batchsize, 4 * hiddensize)
if gates:nElement() ~= batchsize * 4 * hiddensize then
gates:zero()
end

-- forward
gates:addmm(bias_expand, cur_x, Wx)
Expand Down Expand Up @@ -182,7 +186,8 @@ function StepLSTM:backward(input, gradOutput, scale)
local output_gate = gates[{{}, {2 * hiddensize + 1, 3 * hiddensize}}]
local input_transform = gates[{{}, {3 * hiddensize + 1, 4 * hiddensize}}]

grad_gates:resize(batchsize, 4 * hiddensize):zero()
grad_gates:resize(batchsize, 4 * hiddensize)

local grad_input_gate = grad_gates[{{}, {1, hiddensize}}]
local grad_forget_gate = grad_gates[{{}, {hiddensize + 1, 2 * hiddensize}}]
local grad_output_gate = grad_gates[{{}, {2 * hiddensize + 1, 3 * hiddensize}}]
Expand Down
4 changes: 2 additions & 2 deletions benchmark/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Benchmark

On CPU, using Ubuntu 16.04, using float32, Torch LSTM boasts 886 samples/sec compared to TF's 809 samples/sec for LSTM with 512 hiddensize and 64 batchsize.
On the other hand, for 128 hiddensize and 32 batchsize, Torch has 3950 compared to TF's 4130 samples/sec.
On CPU, using Ubuntu 16.04, using float32, Torch LSTM boasts 900 samples/sec compared to TF's 809 samples/sec for LSTM with 512 hiddensize and 64 batchsize.
On the other hand, for 128 hiddensize and 32 batchsize, Torch has 3990 compared to TF's 4130 samples/sec.
6 changes: 3 additions & 3 deletions generic/StepGRU.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ static int nn_(StepGRU_updateOutput)(lua_State *L) {
buffer->size[0] = batchsize;

THTensor_(resize2d)(next_h, batchsize, outputsize);
long nElement = THTensor_(nElement)(gates);
THTensor_(resize2d)(gates, batchsize, 3 * outputsize);
if (nElement != batchsize * 3 * outputsize)
THTensor_(fill)(gates, 0);

THTensor *Wx = THTensor_(newNarrow)(weight, 0, 0, inputsize);
THTensor *Wh = THTensor_(newNarrow)(weight, 0, inputsize, outputsize);
Expand All @@ -32,8 +35,6 @@ static int nn_(StepGRU_updateOutput)(lua_State *L) {
THTensor *update_gate = THTensor_(newNarrow)(gates, 1, outputsize, outputsize); // u = sig(Wx * x + Wh * prev_h + b)
THTensor *hidden_candidate = THTensor_(newNarrow)(gates, 1, 2*outputsize, outputsize); // hc = tanh(Wx * x + Wh * r . prev_h + b)

//THTensor_(fill)(gates, 0);

// forward
THTensor_(addmm)(gates, 1, buffer, 1, cur_x, Wx);
THTensor_(addmm)(sub_gates, 1, sub_gates, 1, prev_h, sub_Wh);
Expand Down Expand Up @@ -84,7 +85,6 @@ static int nn_(StepGRU_backward)(lua_State *L) {
THTensor_(resize2d)(grad_cur_x, batchsize, inputsize);
THTensor_(resize2d)(grad_prev_h, batchsize, outputsize);
THTensor_(resize2d)(grad_gates, batchsize, 3 * outputsize);
THTensor_(fill)(grad_gates, 0);

THTensor *Wx = THTensor_(newNarrow)(weight, 0, 0, inputsize);
THTensor *Wh = THTensor_(newNarrow)(weight, 0, inputsize, outputsize);
Expand Down
6 changes: 3 additions & 3 deletions generic/StepLSTM.c
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,10 @@ static int nn_(StepLSTM_updateOutput)(lua_State *L) {

THTensor_(resize2d)(next_h, batchsize, hiddensize);
THTensor_(resize2d)(next_c, batchsize, hiddensize);

long nElement = THTensor_(nElement)(gates);
THTensor_(resize2d)(gates, batchsize, 4 * hiddensize);
//THTensor_(fill)(gates, 0);
if (nElement != batchsize * 4 * hiddensize)
THTensor_(fill)(gates, 0);

// forward
THTensor_(addmm)(gates, 1, buffer, 1, cur_x, Wx);
Expand Down Expand Up @@ -147,7 +148,6 @@ static int nn_(StepLSTM_backward)(lua_State *L) {
THTensor *grad_Wh = THTensor_(newNarrow)(gradWeight, 0, inputsize, outputsize);

THTensor_(resize2d)(grad_gates, batchsize, 4 * hiddensize);
THTensor_(fill)(grad_gates, 0);

THTensor *grad_input_gate = THTensor_(newNarrow)(grad_gates, 1, 0, hiddensize);
THTensor *grad_forget_gate = THTensor_(newNarrow)(grad_gates, 1, hiddensize, hiddensize);
Expand Down
44 changes: 44 additions & 0 deletions test/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,48 @@ fast LSTM memory: 98.27904510498:2.2750110626221 MB
step LSTM memory: 17.168065071106:2.1289348602295 MB
rec LSTM memory: 13.374607086182:2.0407600402832 MB
seq LSTM memory: 8.8895826339722:3.0098876953125 MB
```

More optimizations

```
th -lrnn -e 'rnn.bigtest({"LSTM","GRU"})'
Running 3 tests
1/3 LSTM_char_rnn ....................................................... [PASS]
2/3 GRU ................................................................. [WAIT]CPU test
old GRU time: 0.039725697040558 seconds
step GRU time: 0.014464259147644 seconds
luarec GRU time: 0.017707204818726 seconds
rec GRU time: 0.013900947570801 seconds
luaseq GRU time: 0.016570293903351 seconds
seq GRU time: 0.012663447856903 seconds
RecGRU-C 1.2738127907136 faster than RecGRU-Lua
RecGRU 2.8577690001509 faster than old GRU
SeqGRU 1.0977221786579 faster than RecGRU
SeqGRU-C 1.3085136126113 faster than SeqGRU-Lua
Memory test
old GRU memory: 82.804834365845:1.833381652832 MB
step GRU memory: 10.018351554871:1.5651426315308 MB
rec GRU memory: 10.018255233765:1.5337238311768 MB
seq GRU memory: 6.3827362060547:1.5385322570801 MB
2/3 GRU ................................................................. [PASS]
3/3 LSTM ................................................................ [WAIT]CPU test
fast LSTM time: 0.044381546974182 seconds
step LSTM time: 0.021313452720642 seconds
luarec LSTM time: 0.021889448165894 seconds
rec LSTM time: 0.017923295497894 seconds
luaseq LSTM time: 0.018705642223358 seconds
seq LSTM time: 0.016467046737671 seconds
RecLSTM-C 1.2212847893104 faster than RecLSTM-Lua
RecLSTM 2.476193453341 faster than FastLSTM
SeqLSTM 1.0884341183591 faster than RecLSTM
SeqLSTM-C 1.1359439565181 faster than SeqLSTM-Lua
Memory test
fast LSTM memory: 98.2790184021:2.2749843597412 MB
step LSTM memory: 17.168484687805:2.1293544769287 MB
rec LSTM memory: 13.375099182129:2.0412521362305 MB
seq LSTM memory: 8.8264684677124:2.0093183517456 MB
3/3 LSTM ................................................................ [PASS]
Completed 0 asserts in 3 tests with 0 failures and 0 errors and 1 warning
--------------------------------------------------------------------------------
```

0 comments on commit e5508b7

Please sign in to comment.