
[Op] cuDNN RNN Symbol #2795

Merged 38 commits into apache:master on Jul 24, 2016

Conversation

sbodenstein
Contributor

This adds an interface to the cuDNN RNN operator. Some issues:

  1. The forward pass in inference mode reproduces https://github.com/soumith/cudnn.torch
  2. The backward pass is currently not working, due to incorrect handling of the dropout descriptor. Correct handling of the dropout state still needs to be added.
  3. It doesn't currently inherit the MXNet seed for dropout. How is this seed accessed?
  4. This symbol currently only supports data in the form [seq length, batch, input size], which is the native cuDNN format. Support for [batch, seq length, input size] should be added as well, but will probably require a temp space + transpose.
  5. Gives an option to return multiple outputs (output + 2 states for LSTM, 1 for the others). By default only a single output is returned, but sometimes you need access to the output states (e.g. when generating text).
  6. Currently only the CUDNN_LINEAR_INPUT option for cudnnRNNInputMode_t is supported.
  7. Uses a single parameter vector. It will be useful to have a Python script that converts this into a dictionary of NDArrays giving each weight + bias for each layer (see the sketch at the end of this description).

See also: #2401
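
A hedged sketch related to item 7 above (not code from this PR; the handle, descriptors, and member names are assumptions): cuDNN provides cudnnGetRNNLinLayerMatrixParams / cudnnGetRNNLinLayerBiasParams to locate each layer's weight matrix and bias inside the flat parameter buffer, which is exactly the offset/shape information a Python conversion script would need.

```cpp
// Hedged sketch: find where one linear-layer weight matrix lives inside the
// flat parameter buffer w_ptr. handle_, rnn_desc_, x_desc_, w_desc_, w_ptr,
// layer and lin_layer_id are assumed names, not identifiers from this PR.
cudnnFilterDescriptor_t mat_desc;
void *mat_ptr = nullptr;
CHECK_EQ(cudnnCreateFilterDescriptor(&mat_desc), CUDNN_STATUS_SUCCESS);
CHECK_EQ(cudnnGetRNNLinLayerMatrixParams(handle_, rnn_desc_,
                                         layer,         // pseudo-layer index
                                         x_desc_,       // descriptor of a single input step
                                         w_desc_, w_ptr,
                                         lin_layer_id,  // which matrix within the layer (gate)
                                         mat_desc, &mat_ptr),
         CUDNN_STATUS_SUCCESS);
// mat_desc now describes the matrix shape and mat_ptr points at its first
// element inside w_ptr; the offset (mat_ptr - w_ptr) and shape are what a
// parameter-view/dictionary builder would record for this layer.
CHECK_EQ(cudnnDestroyFilterDescriptor(mat_desc), CUDNN_STATUS_SUCCESS);
```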

- fixed error in output shape inference
- added cudnn destructors
- completed forward evaluation
- fixed bug where cudnnGetRNNParamsSize needs to be called after cudnnSetRNNDescriptor
- more consistent param names
- removed 'batch_first' option for now. Might add it later again
@piiswrong
Contributor

So the current problem with training is that the reserved space needs to be kept as an output, but its size is unknown during shape inference?
A simple fix is to use cudaMalloc to allocate it during op init. It's not ideal, but since it should be a small buffer it's fine for now.
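
A minimal sketch of that workaround (identifiers ending in "_" are assumed member names, not code from this PR): query the reserve-space size from cuDNN once the input descriptors are set up, allocate it with cudaMalloc when the op is initialized, and free it in the destructor, instead of exposing it as an extra output.

```cpp
// Hedged sketch of the suggested fix; handle_, rnn_desc_, seq_length_,
// x_desc_vec_ and reserve_space_ are assumed member names of the op class.
size_t reserve_bytes = 0;
CHECK_EQ(cudnnGetRNNTrainingReserveSize(handle_, rnn_desc_,
                                        seq_length_, x_desc_vec_.data(),
                                        &reserve_bytes),
         CUDNN_STATUS_SUCCESS);
CHECK_EQ(cudaMalloc(&reserve_space_, reserve_bytes), cudaSuccess);
// ... later, in the destructor:
CHECK_EQ(cudaFree(reserve_space_), cudaSuccess);
```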

@sxjscience
Member

@sbodenstein @antinucleon My concern is that variants of dropout and batch normalization for RNNs, which haven't been included in cuDNN, will soon be standardized, and we may then need to add these new features to our C++ implementation. Also, RNN in TensorFlow is implemented by combining basic symbols (https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell.py). Anyway, writing a C++-compatible version of CuDNNRNN does no harm. We can still implement the wrapper in the scripting language.

@sbodenstein
Contributor Author

@piiswrong: regarding the space computation, I don't know how it's done, as it only depends on a single parameter, the cuDNN handle (cudnnDropoutGetStatesSize(cudnnHandle_t handle, size_t *size)). Is it device-dependent? If not, why not just use a global variable?

OK, will use cudaMalloc for now (and free it in ~CuDNNRNNOp()).
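
For reference, a sketch of the dropout-state handling described above (member names are assumed; this is not the PR's actual code): query the state size, allocate the states with cudaMalloc, hand them to cudnnSetDropoutDescriptor, and release everything in ~CuDNNRNNOp().

```cpp
// Hedged sketch; dropout_desc_, dropout_states_, dropout_byte_, param_.p and
// seed_ are assumed member names.
CHECK_EQ(cudnnDropoutGetStatesSize(handle_, &dropout_byte_), CUDNN_STATUS_SUCCESS);
CHECK_EQ(cudaMalloc(&dropout_states_, dropout_byte_), cudaSuccess);
CHECK_EQ(cudnnCreateDropoutDescriptor(&dropout_desc_), CUDNN_STATUS_SUCCESS);
CHECK_EQ(cudnnSetDropoutDescriptor(dropout_desc_, handle_, param_.p,
                                   dropout_states_, dropout_byte_, seed_),
         CUDNN_STATUS_SUCCESS);
// ... and in ~CuDNNRNNOp():
CHECK_EQ(cudaFree(dropout_states_), cudaSuccess);
CHECK_EQ(cudnnDestroyDropoutDescriptor(dropout_desc_), CUDNN_STATUS_SUCCESS);
```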

@sxjscience
Member

For [4], I think [seq length, batch, input size] is fine and let's just stick to this layout.

@piiswrong
Contributor

@sbodenstein Any update on this?

@sbodenstein
Contributor Author

@piiswrong: I've been on vacation the last two days. I should have some time tomorrow (Monday at the latest) to resolve the remaining issues (I found one or two extra bugs in my code, and still need to handle the dropout states correctly and test Backward against Torch). Apologies for the delay.

@sbodenstein
Contributor Author

@piiswrong: actually, I just committed a fix that should resolve the dropout issue, and it also fixes a few other bugs.

I will spend some time tomorrow testing all the various configurations against Torch. What else do I need to do to merge a first version? A set of Python tests?

- added dropout states
- fixed incorrect handling of variable outputs
@piiswrong
Contributor

Testing against a Python version would be nice. However, since this is GPU-only and currently won't be run on the test server anyway, if you can confirm consistency with Torch I think that's enough for an initial version.

@piiswrong
Contributor

piiswrong commented Jul 24, 2016

@antinucleon Do you have time to do deepmark lstm?

@sbodenstein
Contributor Author

@piiswrong: I reproduce Torch with a wide variety of settings (bidirectional, LSTM, GRU, etc.). I think it's ready to be merged.

@piiswrong
Contributor

Great. I'll merge it after tests finish

@piiswrong
Contributor

Could you update to the current master?

@sbodenstein
Contributor Author

@piiswrong: apologies, done.

@piiswrong merged commit 0460049 into apache:master on Jul 24, 2016
@sbodenstein
Contributor Author

@piiswrong: I want to add a Python function that creates a view of the parameter NDArray, giving the weights and biases of each layer as individual NDArrays. Where should this function live?

@Godricly
Contributor

@sbodenstein: Can you fix this operator for cuDNN v5.0? The function parameters are different between 5.0 and 5.1.

@sbodenstein
Contributor Author

@Godricly: I tested this only with cuDNN v5.0, and there wasn't a problem. Which parameters are different? And what is broken for 5.0?

Also, I assumed they were the same, as the release notes of v5.1 state: "cuDNN 5.1 is fully API compatible with cuDNN 5.0."

@thirdwing
Contributor

@sbodenstein Can you share the code you used to reproduce the Torch results?

seed_), CUDNN_STATUS_SUCCESS);
// RNN descriptors
CHECK_EQ(cudnnCreateRNNDescriptor(&rnn_desc_), CUDNN_STATUS_SUCCESS);
CHECK_EQ(cudnnSetRNNDescriptor(rnn_desc_,
@Godricly
Contributor

This function is different in cudnn 5.0.4, which I used.

@sbodenstein
Contributor Author

Are you referring to cudnnSetDropoutDescriptor? If so, I just downloaded the "cuDNN User Guide" from "cuDNN v5 (May 12, 2016), for CUDA 7.5" on the cuDNN site. I don't see any difference. I also looked in the user guide under "Download cuDNN v5 (May 27, 2016), for CUDA 8.0 RC". Still no difference. These are the only 5.0 releases available on the cuDNN site. So please, can you be very specific about where you are seeing a difference, and with what?

@Godricly
Contributor

Well... they made some changes between cuDNN 5.0.4 (April 2016) and cuDNN 5.0.5. Previously cudnnSetRNNDescriptor had an input parameter seqLength. I've updated my cuDNN; it's not a problem anymore.
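
For reference, a hedged sketch of the difference (the 5.0.4 RC behavior is inferred from the comment above; the later form follows the published cuDNN 5 API; parameter names such as param_.state_size are assumptions):

```cpp
// cuDNN 5.0.4 RC (per the comment above): cudnnSetRNNDescriptor took an
// additional seqLength argument that was removed in 5.0.5; the sequence
// length is instead passed to the forward/backward calls.
CHECK_EQ(cudnnSetRNNDescriptor(rnn_desc_,
                               param_.state_size,   // hiddenSize
                               param_.num_layers,   // numLayers
                               dropout_desc_,
                               CUDNN_LINEAR_INPUT,  // inputMode
                               direction_,          // uni- or bidirectional
                               mode_,               // RNN/LSTM/GRU
                               CUDNN_DATA_FLOAT),
         CUDNN_STATUS_SUCCESS);
```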

@sbodenstein
Contributor Author

Maybe you were using a release candidate?

@sbodenstein
Contributor Author

@thirdwing: sure, here.

@kikoqiu
Contributor

kikoqiu commented Jul 29, 2016

Bug report:
mxnet.symbol.RNN doesn't work with mx.model.FeedForward: the model assumes the first dimension of every shape is the batch size and splits it for multi-device training, while mxnet.symbol.RNN uses Shape3(total_layers, batch_size, param_.state_size) for the RNN init state input.
See executor_manager.py line 219:

data_shapes = {k: tuple([slices[i].stop-slices[i].start] + list(v[1:]))
               for k, v in train_data.provide_data + train_data.provide_label}

@sbodenstein
Contributor Author

@kikoqiu: this is not a bug, it's a design decision. You can use SwapAxis to convert between the [batch, seq length, input size] and [seq length, batch, input size] layouts. Otherwise, we could add support for [batch, seq length, input size] as a symbol option.

@kikoqiu
Contributor

kikoqiu commented Aug 1, 2016

Hi @sbodenstein, I see it's a design element of the cuDNN RNN and usually shouldn't be a problem. However, the code for mx.model.FeedForward will not work with it, as it assumes all input params have the shape (batch_size, ...).

@xlvector
Contributor

Is there a performance comparison between the cuDNN RNN and the previous implementation (built by combining basic symbols)? @sbodenstein
