-
Notifications
You must be signed in to change notification settings - Fork 311
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Multi-function recurrent unit #344
Multi-function recurrent unit #344
Conversation
@juesato @JoostvDoorn @jnhwkim can you review this? Minimum requirements: documentation (README.md), unit tests and code. |
@nicholas-leonard For me, I should have to read the paper first. I'll revisit here to catch up. |
@juesato I would really like to see this get merged. Any developments? |
@nicholas-leonard Sorry, I left this hanging. I'l spend the next two hours working on this (speeding up the CMaxTable stuff, adding docs, and unit tests), and if I don't finish then, I'll continue tomorrow. |
d47f636
to
e244e88
Compare
@nicholas-leonard I got some time to spend on this this weekend, but found a bug, and need a bit more time. I'm hoping to finish up after work tomorrow, but I think it's fairly likely I won't have time, in which case I'll finish on Tuesday. |
2d5bc3a
to
51b5678
Compare
@nicholas-leonard I added unit tests and documentation, and I believe this should be ready to merge. As a future reference, I'm going to leave the training curves on PTB here. Command (I added support for Adam, since that's what's used in the original paper):
As a baseline, if we swap out GRU for MuFuRU, we get this curve
So the training loss looks similar across the two, but generalization seems better with GRU here. |
@juesato Thanks for following through on this to the end! |
Implementation of #269
A few notes:
max
andmin
options turned off, because until I merge Make CMaxTable and CMinTable cunn-compatible torch/nn#954, theCMaxTable
andCMinTable
options aren't compatible with CudaTensors. However, this makes the model considerably worse, converging to validation PPL 164 rather than 134.I think this should merge after torch/nn#954 since until then the performance is poor, but I wanted to get feedback on this before adding unit tests and documentation.
Let me know if you have any questions!
@nicholas-leonard