[Op] cuDNN RNN Symbol #2795
Conversation
- fixed error in output shape inference
- added cudnn destructors
- completed forward evaluation
- fixed bug where cudnnGetRNNParamsSize needs to be called after cudnnSetRNNDescriptor
- more consistent param names - removed 'batch_first' option for now. Might add it later again
- moved calculated param to cudnn_rnn-inl.h
So the current problem with training is that the reserve space needs to be kept as an output, but its size is unknown during shape inference?
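(For context, a minimal sketch of where that size comes from: cuDNN only reports it via cudnnGetRNNTrainingReserveSize once the RNN and per-timestep input descriptors are fully set up, so it isn't available at shape-inference time. The underscore-suffixed names are assumed members of the operator class, not necessarily this PR's code.)

size_t reserve_bytes = 0;
CHECK_EQ(cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_, seq_length_,
                                        x_desc_vec_.data(), &reserve_bytes),
         CUDNN_STATUS_SUCCESS);
void* reserve_space = nullptr;
CHECK_EQ(cudaMalloc(&reserve_space, reserve_bytes), cudaSuccess);
// The same buffer must be handed to cudnnRNNForwardTraining and to the
// backward calls, so it has to persist between forward and backward.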
@sbodenstein @antinucleon My concern is that techniques like dropout and batch normalization in RNNs, which haven't been included in cuDNN yet, will soon be standardized, and we may then need to add these new features to our C++ implementation. Also, RNN in TensorFlow is implemented by combining basic symbols (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell.py). Anyway, writing a C++ compatible version of CuDNNRNN does no harm. We can still implement the wrapper in the script language.
@piiswrong: regarding the space computation, I don't know how it's done, as it only depends on a single parameter, the cudnn handle (cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t *size)). Is it device dependent? If not, why not just use a global variable? OK, will use cudaMalloc for now (and free in ~CuDNNRNNOp()).
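(A minimal sketch of the cudaMalloc approach just described, using the cuDNN v5 dropout API; handle_, dropout_desc_, dropout_states_, dropout_prob_ and seed_ are assumed member names.)

size_t dropout_bytes = 0;
CHECK_EQ(cudnnDropoutGetStatesSize(handle_, &dropout_bytes),
         CUDNN_STATUS_SUCCESS);
CHECK_EQ(cudaMalloc(&dropout_states_, dropout_bytes), cudaSuccess);
CHECK_EQ(cudnnSetDropoutDescriptor(dropout_desc_, handle_, dropout_prob_,
                                   dropout_states_, dropout_bytes, seed_),
         CUDNN_STATUS_SUCCESS);
// In ~CuDNNRNNOp(): CHECK_EQ(cudaFree(dropout_states_), cudaSuccess);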
For [4], I think [seq length, batch, input size] is fine and let's just stick to this layout.
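(For reference, a sketch of how this layout maps onto cuDNN: the RNN API takes an array of seq_length 3-D descriptors, one per timestep of shape [batch, input size, 1]. Variable names are illustrative, not this PR's.)

std::vector<cudnnTensorDescriptor_t> x_desc(seq_length);
int dim[3]    = {batch_size, input_size, 1};
int stride[3] = {input_size, 1, 1};  // each timestep is a contiguous [batch, input] slice
for (int i = 0; i < seq_length; ++i) {
  CHECK_EQ(cudnnCreateTensorDescriptor(&x_desc[i]), CUDNN_STATUS_SUCCESS);
  CHECK_EQ(cudnnSetTensorNdDescriptor(x_desc[i], CUDNN_DATA_FLOAT,
                                      3, dim, stride),
           CUDNN_STATUS_SUCCESS);
}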
@sbodenstein Any update on this?
@piiswrong: I've been on vacation the last two days. I should have some time tomorrow (Monday at the latest) to resolve the remaining issues (I found one or two extra bugs in my code, and also need to deal with dropout states correctly and test Backward against Torch). Apologies for the delay.
@piiswrong: actually, I just committed a fix that should resolve the dropout issue, and also fixes a few other bugs. I will spend some time tomorrow testing all the various configurations against Torch. What else do I need to do to merge a first version? A set of Python tests?
- added dropout states
- fixed incorrect handling of variable outputs
Testing against a Python version would be nice. However, since this is GPU-only and currently won't be run on the test server anyway, if you can confirm consistency with Torch I think that's enough for an initial version.
@antinucleon Do you have time to do deepmark lstm?
@piiswrong: I can reproduce Torch results with a wide variety of settings (bidirectional, LSTM, GRU, etc). I think it's ready to be merged.
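(For readers following along, the settings mentioned map onto cuDNN's enums roughly as below; ToCudnnMode is a hypothetical helper, not necessarily this PR's code.)

cudnnRNNMode_t ToCudnnMode(const std::string& mode) {
  if (mode == "rnn_relu") return CUDNN_RNN_RELU;  // vanilla RNN, ReLU activation
  if (mode == "rnn_tanh") return CUDNN_RNN_TANH;  // vanilla RNN, tanh activation
  if (mode == "lstm")     return CUDNN_LSTM;
  if (mode == "gru")      return CUDNN_GRU;
  LOG(FATAL) << "unknown RNN mode: " << mode;
  return CUDNN_LSTM;  // unreachable
}
// bidirectional selects CUDNN_BIDIRECTIONAL rather than CUDNN_UNIDIRECTIONAL
// for the direction argument of cudnnSetRNNDescriptor.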
Great. I'll merge it after tests finish.
Could you update to current master?
@piiswrong: apologies, done.
@piiswrong: I want to add a python function that creates a view of the parameter NDArray that gives the weights and biases of each layer as individual NDArrays. Where should this function live?
@sbodenstein: Can you fix this operator for cudnn v5.0? The function parameters are different between 5.0 and 5.1.
@Godricly: I tested this only for cudnn v5.0, and there wasn't a problem. Which parameters are different? And what is broken for 5.0? Also, I assumed they were the same, as the release notes of v5.1 state: "cuDNN 5.1 is fully API compatible with cuDNN 5.0."
@sbodenstein Can you share the code you used to reproduce the Torch results?
             seed_), CUDNN_STATUS_SUCCESS);
// RNN descriptors
CHECK_EQ(cudnnCreateRNNDescriptor(&rnn_desc_), CUDNN_STATUS_SUCCESS);
CHECK_EQ(cudnnSetRNNDescriptor(rnn_desc_,
This function is different in cudnn 5.0.4, which I used.
Are you referring to cudnnSetDropoutDescriptor? If so, I just downloaded the "cuDNN User Guide" from "cuDNN v5 (May 12, 2016), for CUDA 7.5" from the cuDNN site. I don't see any difference. I also looked in the user guide under "Download cuDNN v5 (May 27, 2016), for CUDA 8.0 RC". Still no difference. These are the only 5.0 releases available on the cuDNN site. So please, can you be super specific as to how you are seeing a difference, and with what?
Well... they made some change between cudnn 5.0.4 (April 2016) and cudnn 5.0.5: previously cudnnSetRNNDescriptor had an input parameter seqLength. I've updated my cudnn. It's not a problem anymore.
Maybe you were using a release candidate?
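(If supporting both releases ever becomes necessary, a compile-time guard on CUDNN_VERSION, which cudnn.h defines as major*1000 + minor*100 + patch, would be one option. A sketch only: the 5.0.4 argument order is assumed from the report above, not verified, and the underscore-suffixed names are assumed members.)

#if CUDNN_VERSION >= 5005
  CHECK_EQ(cudnnSetRNNDescriptor(rnn_desc_, state_size_, num_layers_,
                                 dropout_desc_, input_mode_, direction_,
                                 mode_, dtype_),
           CUDNN_STATUS_SUCCESS);
#else  // 5.0.4 reportedly also took a seqLength argument
  CHECK_EQ(cudnnSetRNNDescriptor(rnn_desc_, state_size_, seq_length_,
                                 num_layers_, dropout_desc_, input_mode_,
                                 direction_, mode_, dtype_),
           CUDNN_STATUS_SUCCESS);
#endif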
@thirdwing: sure, here.
Bug report:
@kikoqiu: this is not a bug, it's a design decision. You can use SwapAxis to put it into [batch, seq length, in size] form. Otherwise, we could add support for [batch, seq len, in size] as a symbol option.
Hi @sbodenstein, I see it's a design element of RNN in cudnn, and usually shouldn't be a problem. However, the code for mx.model.FeedForward would not work with it, as it assumes all input params have the shape (batch_size, ...).
Is there a performance comparison between the cuDNN RNN and the previous implementation (combined from basic symbols)? @sbodenstein
This adds an interface to the cuDNN RNN operator. Some issues:
See also: #2401