diff --git a/SeqGRU.lua b/SeqGRU.lua
index 6b772f6..9cc4ac9 100644
--- a/SeqGRU.lua
+++ b/SeqGRU.lua
@@ -97,7 +97,12 @@ function SeqGRU:updateOutput(input)
    local h = self.output
    h:resize(seqlen, batchsize, outputsize):zero()
-   self.gates:resize(seqlen, batchsize, 3 * outputsize):zero()
+
+   local nElement = self.gates:nElement()
+   self.gates:resize(seqlen, batchsize, 3 * outputsize)
+   if nElement ~= seqlen * batchsize * 3 * outputsize then
+      self.gates:zero()
+   end
 
    local prev_h = h0
    if input.nn and input.nn.StepGRU_updateOutput and not self.forceLua then
@@ -184,7 +189,8 @@ function SeqGRU:backward(input, gradOutput, scale)
       local u = self.gates[{t, {}, {outputsize + 1, 2 * outputsize}}]
       local hc = self.gates[{t, {}, {2 * outputsize + 1, 3 * outputsize}}]
 
-      local grad_a = self.grad_a_buffer:resize(batchsize, 3 * outputsize):zero()
+      local grad_a = self.grad_a_buffer:resize(batchsize, 3 * outputsize)
+
       local grad_ar = grad_a[{{}, {1, outputsize}}]
       local grad_au = grad_a[{{}, {outputsize + 1, 2 * outputsize}}]
       local grad_ahc = grad_a[{{}, {2 * outputsize + 1, 3 * outputsize}}]
@@ -192,7 +198,7 @@ function SeqGRU:backward(input, gradOutput, scale)
       -- use grad_au as temporary buffer to compute grad_ahc.
       local grad_hc = grad_au:fill(0):addcmul(grad_next_h, -1, u, grad_next_h)
-      grad_ahc:fill(1):addcmul(-1, hc,hc):cmul(grad_hc)
+      grad_ahc:fill(1):addcmul(-1, hc, hc):cmul(grad_hc)
       local grad_r = grad_au:fill(0):addmm(grad_ahc, Wh[{{}, {2 * outputsize + 1, 3 * outputsize}}]:t()):cmul(prev_h)
       grad_ar:fill(1):add(-1, r):cmul(r):cmul(grad_r)
diff --git a/SeqLSTM.lua b/SeqLSTM.lua
index 36cefca..2f54e54 100644
--- a/SeqLSTM.lua
+++ b/SeqLSTM.lua
@@ -145,7 +145,12 @@ function SeqLSTM:updateOutput(input)
    local h, c = self.output, self.cell
    h:resize(seqlen, batchsize, outputsize)
    c:resize(seqlen, batchsize, hiddensize)
+
+   local nElement = self.gates:nElement()
    self.gates:resize(seqlen, batchsize, 4 * hiddensize)
+   if nElement ~= seqlen * batchsize * 4 * hiddensize then
+      self.gates:zero()
+   end
 
    local prev_h, prev_c = h0, c0
    if input.nn and input.nn.StepLSTM_updateOutput and not self.forceLua then
diff --git a/StepGRU.lua b/StepGRU.lua
index ea50c71..b149db4 100644
--- a/StepGRU.lua
+++ b/StepGRU.lua
@@ -52,8 +52,12 @@ function StepGRU:updateOutput(input)
    local Wh = self.weight:narrow(1, inputsize + 1, self.outputsize)
 
    next_h:resize(batchsize, outputsize)
-   self.gates:resize(batchsize, 3 * outputsize):zero()
    local gates = self.gates
+   local nElement = gates:nElement()
+   gates:resize(batchsize, 3 * outputsize)
+   if nElement ~= batchsize * 3 * outputsize then
+      gates:zero()
+   end
    gates:addmm(bias_expand, cur_x, Wx)
 
    local sub_gates = gates:narrow(2, 1, 2 * outputsize)
@@ -92,7 +96,6 @@ function StepGRU:backward(input, gradOutput, scale)
    scale = scale or 1.0
    assert(scale == 1.0, 'must have scale=1')
-   -- local grad_gates = torch.getBuffer('StepGRU', 'grad_gates', self.gates) -- batchsize x 3*outputsize
    local buffer = torch.getBuffer('StepGRU', 'buffer', self.gates) -- 1 x 3*outputsize
@@ -125,7 +128,8 @@ function StepGRU:backward(input, gradOutput, scale)
    local update_gate = gates:narrow(2, outputsize + 1, outputsize)
    local hidden_candidate = gates:narrow(2, 2 * outputsize + 1, outputsize)
 
-   grad_gates:resize(batchsize, 3 * outputsize):zero()
+   grad_gates:resize(batchsize, 3 * outputsize)
+
    local grad_reset_gate = grad_gates:narrow(2, 1, outputsize)
    local grad_update_gate = grad_gates:narrow(2, outputsize + 1, outputsize)
    local grad_hidden_candidate = grad_gates:narrow(2, 2 * outputsize + 1, outputsize)
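All of the Lua hunks above apply the same buffer-reuse idiom, so a standalone sketch may help (the helper name `resizeAndMaybeZero` is illustrative, not part of the patch): `Tensor:resize` leaves the existing storage in place when the total element count is unchanged, so an explicit `zero()` is only needed when the resize actually reallocated and the new storage may hold garbage. Note that this is also why the original `gates:nElement()`-after-resize comparison was a bug: after the resize the counts always match, so the buffer would never be zeroed; the captured pre-resize `nElement` must be used, as in the Seq* hunks and the C hunks below.

```lua
-- Minimal sketch of the idiom, assuming plain Torch tensors.
local function resizeAndMaybeZero(buffer, rows, cols)
   local nElement = buffer:nElement()
   buffer:resize(rows, cols)
   if buffer:nElement() ~= nElement then
      buffer:zero() -- element count changed: fresh storage may be uninitialized
   end
   return buffer
end
```

With a constant batch size this zeroes the gate buffer once, on the first call, and every subsequent step skips a full write pass over `gates`.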
diff --git a/StepLSTM.lua b/StepLSTM.lua
index 7285105..f3d4b1b 100644
--- a/StepLSTM.lua
+++ b/StepLSTM.lua
@@ -82,8 +82,12 @@ function StepLSTM:updateOutput(input)
    next_h:resize(batchsize, hiddensize)
    next_c:resize(batchsize, hiddensize)
-   self.gates:resize(batchsize, 4 * hiddensize):zero()
    local gates = self.gates
+   local nElement = gates:nElement()
+   gates:resize(batchsize, 4 * hiddensize)
+   if nElement ~= batchsize * 4 * hiddensize then
+      gates:zero()
+   end
 
    -- forward
    gates:addmm(bias_expand, cur_x, Wx)
@@ -182,7 +186,8 @@ function StepLSTM:backward(input, gradOutput, scale)
    local output_gate = gates[{{}, {2 * hiddensize + 1, 3 * hiddensize}}]
    local input_transform = gates[{{}, {3 * hiddensize + 1, 4 * hiddensize}}]
 
-   grad_gates:resize(batchsize, 4 * hiddensize):zero()
+   grad_gates:resize(batchsize, 4 * hiddensize)
+
    local grad_input_gate = grad_gates[{{}, {1, hiddensize}}]
    local grad_forget_gate = grad_gates[{{}, {hiddensize + 1, 2 * hiddensize}}]
    local grad_output_gate = grad_gates[{{}, {2 * hiddensize + 1, 3 * hiddensize}}]
diff --git a/benchmark/README.md b/benchmark/README.md
index 1f7a046..eac7edd 100644
--- a/benchmark/README.md
+++ b/benchmark/README.md
@@ -1,4 +1,4 @@
 # Benchmark
 
-On CPU, using Ubuntu 16.04, using float32, Torch LSTM boasts 886 samples/sec compared to TF’s 809 samples/sec for LSTM with 512 hiddensize and 64 batchsize.
-On the other hand, for 128 hiddensize and 32 batchsize, Torch has 3950 compared to TF’s 4130 samples/sec.
\ No newline at end of file
+On CPU, using Ubuntu 16.04, using float32, Torch LSTM boasts 900 samples/sec compared to TF’s 809 samples/sec for LSTM with 512 hiddensize and 64 batchsize.
+On the other hand, for 128 hiddensize and 32 batchsize, Torch has 3990 compared to TF’s 4130 samples/sec.
\ No newline at end of file
diff --git a/generic/StepGRU.c b/generic/StepGRU.c
index e17a5ad..93da17a 100644
--- a/generic/StepGRU.c
+++ b/generic/StepGRU.c
@@ -22,7 +22,10 @@ static int nn_(StepGRU_updateOutput)(lua_State *L) {
   buffer->size[0] = batchsize;
 
   THTensor_(resize2d)(next_h, batchsize, outputsize);
+  long nElement = THTensor_(nElement)(gates);
   THTensor_(resize2d)(gates, batchsize, 3 * outputsize);
+  if (nElement != batchsize * 3 * outputsize)
+    THTensor_(fill)(gates, 0);
 
   THTensor *Wx = THTensor_(newNarrow)(weight, 0, 0, inputsize);
   THTensor *Wh = THTensor_(newNarrow)(weight, 0, inputsize, outputsize);
@@ -32,8 +35,6 @@ static int nn_(StepGRU_updateOutput)(lua_State *L) {
   THTensor *update_gate = THTensor_(newNarrow)(gates, 1, outputsize, outputsize); // u = sig(Wx * x + Wh * prev_h + b)
   THTensor *hidden_candidate = THTensor_(newNarrow)(gates, 1, 2*outputsize, outputsize); // hc = tanh(Wx * x + Wh * r . prev_h + b)
-  //THTensor_(fill)(gates, 0);
-
   // forward
   THTensor_(addmm)(gates, 1, buffer, 1, cur_x, Wx);
   THTensor_(addmm)(sub_gates, 1, sub_gates, 1, prev_h, sub_Wh);
@@ -84,7 +85,6 @@ static int nn_(StepGRU_backward)(lua_State *L) {
   THTensor_(resize2d)(grad_cur_x, batchsize, inputsize);
   THTensor_(resize2d)(grad_prev_h, batchsize, outputsize);
   THTensor_(resize2d)(grad_gates, batchsize, 3 * outputsize);
-  THTensor_(fill)(grad_gates, 0);
 
   THTensor *Wx = THTensor_(newNarrow)(weight, 0, 0, inputsize);
   THTensor *Wh = THTensor_(newNarrow)(weight, 0, inputsize, outputsize);
diff --git a/generic/StepLSTM.c b/generic/StepLSTM.c
index 3e46d6f..d64449e 100644
--- a/generic/StepLSTM.c
+++ b/generic/StepLSTM.c
@@ -29,9 +29,10 @@ static int nn_(StepLSTM_updateOutput)(lua_State *L) {
   THTensor_(resize2d)(next_h, batchsize, hiddensize);
   THTensor_(resize2d)(next_c, batchsize, hiddensize);
-
+  long nElement = THTensor_(nElement)(gates);
   THTensor_(resize2d)(gates, batchsize, 4 * hiddensize);
-  //THTensor_(fill)(gates, 0);
+  if (nElement != batchsize * 4 * hiddensize)
+    THTensor_(fill)(gates, 0);
 
   // forward
   THTensor_(addmm)(gates, 1, buffer, 1, cur_x, Wx);
@@ -147,7 +148,6 @@ static int nn_(StepLSTM_backward)(lua_State *L) {
   THTensor *grad_Wh = THTensor_(newNarrow)(gradWeight, 0, inputsize, outputsize);
 
   THTensor_(resize2d)(grad_gates, batchsize, 4 * hiddensize);
-  THTensor_(fill)(grad_gates, 0);
 
   THTensor *grad_input_gate = THTensor_(newNarrow)(grad_gates, 1, 0, hiddensize);
   THTensor *grad_forget_gate = THTensor_(newNarrow)(grad_gates, 1, hiddensize, hiddensize);
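Dropping the unconditional `fill(grad_gates, 0)` / `:zero()` calls is safe because every element of these buffers is overwritten before it is read: the first touch is either a `fill` or an `addmm` in its three-tensor form, which replaces the target with `M + mat1 * mat2` rather than accumulating into stale contents. A quick illustrative check of that overwrite behaviour, mirroring the `gates:addmm(bias_expand, cur_x, Wx)` call in the forward pass (sizes are arbitrary):

```lua
-- res-form addmm: gates = bias_expand + cur_x * Wx,
-- regardless of what gates held beforehand.
local batchsize, inputsize, outputsize = 2, 4, 3
local gates = torch.Tensor(batchsize, 3 * outputsize):fill(123) -- stale contents
local bias_expand = torch.zeros(batchsize, 3 * outputsize)
local cur_x = torch.randn(batchsize, inputsize)
local Wx = torch.randn(inputsize, 3 * outputsize)
gates:addmm(bias_expand, cur_x, Wx)
print((gates - cur_x * Wx):abs():max()) -- ~0: the stale 123s are gone
```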
diff --git a/test/README.md b/test/README.md
index b104498..49d8142 100644
--- a/test/README.md
+++ b/test/README.md
@@ -36,4 +36,48 @@
 fast LSTM memory: 98.27904510498:2.2750110626221 MB
 step LSTM memory: 17.168065071106:2.1289348602295 MB
 rec LSTM memory: 13.374607086182:2.0407600402832 MB
 seq LSTM memory: 8.8895826339722:3.0098876953125 MB
+```
+
+More optimizations
+
+```
+th -lrnn -e 'rnn.bigtest({"LSTM","GRU"})'
+Running 3 tests
+1/3 LSTM_char_rnn ....................................................... [PASS]
+2/3 GRU ................................................................. [WAIT]CPU test
+old GRU time: 0.039725697040558 seconds
+step GRU time: 0.014464259147644 seconds
+luarec GRU time: 0.017707204818726 seconds
+rec GRU time: 0.013900947570801 seconds
+luaseq GRU time: 0.016570293903351 seconds
+seq GRU time: 0.012663447856903 seconds
+RecGRU-C 1.2738127907136 faster than RecGRU-Lua
+RecGRU 2.8577690001509 faster than old GRU
+SeqGRU 1.0977221786579 faster than RecGRU
+SeqGRU-C 1.3085136126113 faster than SeqGRU-Lua
+Memory test
+old GRU memory: 82.804834365845:1.833381652832 MB
+step GRU memory: 10.018351554871:1.5651426315308 MB
+rec GRU memory: 10.018255233765:1.5337238311768 MB
+seq GRU memory: 6.3827362060547:1.5385322570801 MB
+2/3 GRU ................................................................. [PASS]
+3/3 LSTM ................................................................ [WAIT]CPU test
+fast LSTM time: 0.044381546974182 seconds
+step LSTM time: 0.021313452720642 seconds
+luarec LSTM time: 0.021889448165894 seconds
+rec LSTM time: 0.017923295497894 seconds
+luaseq LSTM time: 0.018705642223358 seconds
+seq LSTM time: 0.016467046737671 seconds
+RecLSTM-C 1.2212847893104 faster than RecLSTM-Lua
+RecLSTM 2.476193453341 faster than FastLSTM
+SeqLSTM 1.0884341183591 faster than RecLSTM
+SeqLSTM-C 1.1359439565181 faster than SeqLSTM-Lua
+Memory test
+fast LSTM memory: 98.2790184021:2.2749843597412 MB
+step LSTM memory: 17.168484687805:2.1293544769287 MB
+rec LSTM memory: 13.375099182129:2.0412521362305 MB
+seq LSTM memory: 8.8264684677124:2.0093183517456 MB
+3/3 LSTM ................................................................ [PASS]
+Completed 0 asserts in 3 tests with 0 failures and 0 errors and 1 warning
+--------------------------------------------------------------------------------
 ```
\ No newline at end of file
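A rough way to sanity-check the benchmark/README.md deltas (886 to 900 samples/sec at 512 hiddensize / 64 batchsize) is a timing loop like the sketch below. It assumes the rnn package is installed; the sequence length and the samples/sec definition are guesses, not the settings behind the published numbers.

```lua
require 'rnn'

local seqlen, batchsize, hiddensize = 50, 64, 512
local lstm = nn.SeqLSTM(hiddensize, hiddensize)
local input = torch.randn(seqlen, batchsize, hiddensize)

lstm:forward(input) -- warm-up: first call allocates and zeroes the gate buffers
local nloop = 10
local timer = torch.Timer()
for i = 1, nloop do
   lstm:forward(input)
end
print(string.format('%.0f samples/sec', nloop * batchsize / timer:time().real))
```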