Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 0460049

Browse files
sbodensteinpiiswrong
authored andcommitted
[Op] cuDNN RNN Symbol (#2795)
* - first commit * - removed unnecssary commented out code - fixed error in output shape inference * - some renaming - added cudnn destructors * - added dropout * - major refactor - completed forward evaluation * - added parameter size test - fixed bug where cudnnGetRNNParamsSize needs to be called after cudnnSetRNNDescriptor * - checks for contiguous input tensors - more consistent param names - removed 'batch_first' option for now. Might add it later again * - fixed input names * - added backward method * - small fix for in/out names * - fixed bug: parameters can't have underscore * - fixed off-by-two error in weight shape inference for bidirectional net - moved calculated param to cudnn_rnn-inl.h * - added option to control num outputs * - removed lint * - correct handling of backward dependencies * - fix lint * - first commit * - removed unnecssary commented out code - fixed error in output shape inference * - some renaming - added cudnn destructors * - added dropout * - major refactor - completed forward evaluation * - added parameter size test - fixed bug where cudnnGetRNNParamsSize needs to be called after cudnnSetRNNDescriptor * - checks for contiguous input tensors - more consistent param names - removed 'batch_first' option for now. Might add it later again * - fixed input names * - added backward method * - small fix for in/out names * - fixed bug: parameters can't have underscore * - fixed off-by-two error in weight shape inference for bidirectional net - moved calculated param to cudnn_rnn-inl.h * - added option to control num outputs * - removed lint * - correct handling of backward dependencies * - fix lint * - fix type narrowing bug * - fixed incorrect dropout parameter - added dropout states - fixed incorrect handling of variable outputs * - fix incorrect cell state forward handling * - fixed lint by replacing unsigned long long with uint64_t
1 parent 06841a0 commit 0460049

File tree

4 files changed

+891
-0
lines changed

4 files changed

+891
-0
lines changed

0 commit comments

Comments
 (0)