
Multi-function recurrent unit #344

Merged (4 commits) on Feb 17, 2017

Conversation

@juesato (Contributor) commented on Sep 24, 2016

Implementation of #269

A few notes:

  • The model does not train well with plain SGD, or at least requires lower learning rates than typical LSTM/GRU networks. Using the same hyperparameters as the LSTM recurrent language model example, the model doesn't appear to converge, and training PPL doesn't drop below 300.
  • I added an Adam option to the language model script, since the original MuFuRU paper also uses Adam, and this addresses the issue; a minimal sketch of the optimizer swap follows this list. (I also checked that the change preserves behavior under SGD by re-running the LSTM example and confirming the stats are unchanged: it converges to training PPL 50, validation PPL 120, and test PPL 116 after 7 epochs both before and after the change.)
  • Right now I have the max and min composition operations turned off, because until torch/nn#954 ("Make CMaxTable and CMinTable cunn-compatible") is merged, CMaxTable and CMinTable don't work with CudaTensors. This makes the model considerably worse: it converges to validation PPL 164 rather than 134. (The second sketch below shows where these operations fit in.)
  • The relative performance of the different models is similar to the results in the paper, but the absolute losses don't match up. I emailed Dirk Weissenborn about this, and he was kind enough to share his TensorFlow training scripts. In TensorFlow, both MuFuRU and GRU converge to about 144 validation PPL. In this Torch implementation, MuFuRU and GRU-implemented-as-MuFuRU converge to about 134 validation PPL, while the plain GRU converges to 121 on the validation set, which may be due to the forget gate biases.
  • I have a spreadsheet summarizing the experiments I've run so far here: https://docs.google.com/spreadsheets/d/1snsSUCqv0frfjBztkkwONfiMT5kyKjzJQr-KI1F1qvE/edit?usp=sharing
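
For concreteness, the Adam option amounts to an update step like the sketch below. This is a minimal illustration, not the PR's actual code: it assumes model, criterion, inputs, and targets are set up as in examples/recurrent-language-model.lua, and the norm clipping mimics --cutoff 5.

-- Minimal sketch of one Adam step with torch/optim (assumptions above).
require 'optim'

local params, gradParams = model:getParameters()
local optimState = {learningRate = 0.01}  -- matches --startlr 0.01

local function feval(x)
   if x ~= params then params:copy(x) end
   gradParams:zero()
   local outputs = model:forward(inputs)
   local loss = criterion:forward(outputs, targets)
   model:backward(inputs, criterion:backward(outputs, targets))
   -- clip the gradient norm, in the spirit of --cutoff 5
   local norm = gradParams:norm()
   if norm > 5 then gradParams:mul(5 / norm) end
   return loss, gradParams
end

optim.adam(feval, params, optimState)  -- one update per mini-batch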

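For context on the max/min point, here is a conceptual sketch (mine, not the PR's actual code) of how the MuFuRU composition operations map onto torch/nn table modules. Each operation combines the previous state s_{t-1} with the candidate v_t; the paper's full set also includes a "forget" op that zeroes the state.

-- Conceptual sketch: MuFuRU composition operations as nn table modules.
-- Each module maps the table {prev_state, candidate} to a new state.
require 'nn'

local ops = {
   keep    = nn.SelectTable(1),  -- s_t = s_{t-1}
   replace = nn.SelectTable(2),  -- s_t = v_t
   mul     = nn.CMulTable(),     -- s_t = s_{t-1} .* v_t (element-wise)
   diff    = nn.CSubTable(),     -- s_t = s_{t-1} - v_t
   max     = nn.CMaxTable(),     -- blocked on torch/nn#954 for CudaTensors
   min     = nn.CMinTable(),     -- likewise
}

-- Example: element-wise max of a batch of states and candidates.
local s = torch.randn(32, 200)  -- batchsize 32, hiddensize 200, as above
local v = torch.randn(32, 200)
local out = ops.max:forward({s, v})
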
I think this should merge only after torch/nn#954, since performance is poor until then, but I wanted to get feedback before adding unit tests and documentation.

Let me know if you have any questions!

@nicholas-leonard (Member) commented

@juesato @JoostvDoorn @jnhwkim can you review this? Minimum requirements: documentation (README.md), unit tests and code.

@jnhwkim (Contributor) commented on Sep 27, 2016

@nicholas-leonard I'll need to read the paper first; I'll revisit this thread once I've caught up.

@nicholas-leonard (Member) commented

@juesato I would really like to see this get merged. Any developments?

@juesato (Contributor, Author) commented on Jan 26, 2017

@nicholas-leonard Sorry, I left this hanging. I'll spend the next two hours working on this (speeding up the CMaxTable stuff, adding docs and unit tests), and if I don't finish, I'll continue tomorrow.

@juesato (Contributor, Author) commented on Jan 30, 2017

@nicholas-leonard I got some time to work on this over the weekend, but found a bug and need a bit more time. I'm hoping to finish up after work tomorrow, but it's fairly likely I won't have time, in which case I'll finish on Tuesday.

@juesato (Contributor, Author) commented on Jan 31, 2017

@nicholas-leonard I added unit tests and documentation, and I believe this should be ready to merge.

For future reference, I'm leaving the training curves on PTB here.

Command (I added support for Adam, since that's what's used in the original paper):

CUDA_VISIBLE_DEVICES=0 th examples/recurrent-language-model.lua --progress --cuda --mfru --seqlen 20 --uniform 0.1 --hiddensize '{200}' --batchsize 32 --maxepoch 20 --device 1 --adam --startlr 0.01 --cutoff 5
Epoch #1 :	
 [======================================== 29048/29048 ================================>]  Tot: 2m18s | Step: 4ms       
learning rate	0.009975025	
mean gradParam norm	8.4984952035468	
Speed : 0.004788 sec/batch 	
Training PPL : 282.28724615675	
Validation PPL : 195.89575663903	
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485858192:1.t7	
	
Epoch #2 :	
 [======================================== 29048/29048 ================================>]  Tot: 2m17s | Step: 4ms       
learning rate	0.00995005	
mean gradParam norm	9.6138208177836	
Speed : 0.004739 sec/batch 	
Training PPL : 138.24728917969	
Validation PPL : 152.95716046988	
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485858192:1.t7	
	
Epoch #3 :	
 [======================================== 29048/29048 ================================>]  Tot: 2m17s | Step: 4ms       
learning rate	0.009925075	
mean gradParam norm	10.585762301436	
Speed : 0.004736 sec/batch 	
Training PPL : 99.841825199886	
Validation PPL : 138.14327248568	
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485858192:1.t7	
	
Epoch #4 :	
 [======================================== 29048/29048 ================================>]  Tot: 2m17s | Step: 4ms       
learning rate	0.0099001	
mean gradParam norm	11.394943580532	
Speed : 0.004733 sec/batch 	
Training PPL : 78.723720426491	
Validation PPL : 133.5123597645	
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485858192:1.t7	
	
Epoch #5 :	
 [======================================== 29048/29048 ================================>]  Tot: 2m17s | Step: 4ms       
learning rate	0.009875125	
mean gradParam norm	12.078977865072	
Speed : 0.004733 sec/batch 	
Training PPL : 65.066691855094	
Validation PPL : 134.65883864519	
	
Epoch #6 :	
 [======================================== 29048/29048 ================================>]  Tot: 2m17s | Step: 4ms       
learning rate	0.00985015	
mean gradParam norm	12.725478478451	
Speed : 0.004731 sec/batch 	
Training PPL : 55.300139163342	
Validation PPL : 139.33760207709	

As a baseline, if we swap the MuFuRU out for a GRU, we get this curve:

Epoch #1 :	
 [======================================== 29048/29048 ================================>]  Tot: 1m36s | Step: 3ms       
learning rate	0.009975025	
mean gradParam norm	7.7288590271154	
Speed : 0.003351 sec/batch 	
Training PPL : 252.14280286606	
Validation PPL : 178.4052220798	
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485859742:1.t7	
	
Epoch #2 :	
 [======================================== 29048/29048 ================================>]  Tot: 1m37s | Step: 3ms       
learning rate	0.00995005	
mean gradParam norm	8.4460625165352	
Speed : 0.003365 sec/batch 	
Training PPL : 127.50807904334	
Validation PPL : 141.2727270623	
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485859742:1.t7	
	
Epoch #3 :	
 [======================================== 29048/29048 ================================>]  Tot: 1m37s | Step: 3ms       
learning rate	0.009925075	
mean gradParam norm	9.040807982285	
Speed : 0.003365 sec/batch 	
Training PPL : 94.050081227786	
Validation PPL : 127.12406725727	
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485859742:1.t7	
	
Epoch #4 :	
 [======================================== 29048/29048 ================================>]  Tot: 1m37s | Step: 3ms       
learning rate	0.0099001	
mean gradParam norm	9.548776472761	
Speed : 0.003364 sec/batch 	
Training PPL : 75.183904110218	
Validation PPL : 121.60937118905	
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485859742:1.t7	
	
Epoch #5 :	
 [======================================== 29048/29048 ================================>]  Tot: 1m37s | Step: 3ms       
learning rate	0.009875125	
mean gradParam norm	10.028425562097	
Speed : 0.003355 sec/batch 	
Training PPL : 62.757740662728	
Validation PPL : 121.03703331266	
Found new minima. Saving to /var/storage/shared/mscog/sys/jobs/application_1485367302006_0415/save/rnnlm/ptb:phlrr1015:1485859742:1.t7	

So training loss looks similar for the two models, but generalization seems better with the GRU here. I haven't done any hyperparameter search; I just thought this could be useful for anyone looking to sanity-check results in the future.

@nicholas-leonard merged commit 51b5678 into Element-Research:master on Feb 17, 2017
@nicholas-leonard (Member) commented

@juesato Thanks for following through on this to the end!
