 
 namespace singa {
 
-RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-                     const std::string Rnn_mode, const float Dropout, const bool bidirectional) {
+RNNHandle::RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                     const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size) {
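+  // input is expected in (seq_length, batch_size, input_size) layout,
+  // as the shape checks below read it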
+
+  CHECK_EQ(input.shape(2), Input_size);
+  batch_size_ = input.shape(1);
+  seq_length_ = input.shape(0);
 
   input_size_ = Input_size;
   CHECK_GT(input_size_, 0u);
@@ -28,68 +32,62 @@ RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const si
   }
   // the first constant (4) is the size of float
   // the second constant (2, 8, 6) is the number of sets of params
-  int mult = 1;
-  if (rnn_mode_ == "relu" || rnn_mode_ == "tanh")
-    mult *= 1;
-  else if (rnn_mode_ == "lstm")
-    mult *= 4;
-  else if (rnn_mode_ == "gru")
-    mult *= 3;
-  if (bidirectional)
-    mult *= 2;
-
-  weight_size = 0;
-  for (size_t i = 0; i < num_stacks_; i++) {
-    size_t dim = hidden_size_ * (input_size_ + hidden_size_ + 2);
-    if (i > 0)
-      dim = hidden_size_ * (hidden_size_ + hidden_size_ + 2);
-    weight_size += mult * dim;
-  }
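+  // the per-mode computation removed above now lives with the caller, which
+  // passes the total in as Weight_size; e.g. a one-stack unidirectional LSTM
+  // (mult = 4) gives weight_size = 4 * hidden_size_ * (input_size_ + hidden_size_ + 2)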
+  weight_size = Weight_size;
+
 };
 
 #ifdef USE_CUDNN
 
-CudnnRNNHandle::CudnnRNNHandle(const vector<Tensor> &inputs, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-                               const std::string Rnn_mode, const float Dropout, const bool bidirectional):
-  RNNHandle(Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional) {
+CudnnRNNHandle::CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                               const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size):
+  RNNHandle(input, Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional, Weight_size) {
 
-  CHECK_GT(inputs.size(), 1u + has_cell_);
-  size_t num_x = inputs.size() - has_cell_ - 1;
-
-  DataType dtype = inputs.at(0).data_type();
-  if (rnn_desc_ != nullptr)
-    CHECK_EQ(dtype_, GetCudnnDataType(dtype))
-        << "Cannot change cudnn data type during training from " << dtype_
-        << " to " << GetCudnnDataType(dtype);
-  else
-    dtype_ = GetCudnnDataType(dtype);
+  DataType dtype = input.data_type();
+  dtype_ = GetCudnnDataType(dtype);
 
-  UpdateStates(num_x, inputs);
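+  // descriptors and workspaces are now set up once at construction from the
+  // single input tensor, instead of being refreshed per batch via UpdateStates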
+  UpdateIODescriptors(input);
+  ResetHiddenAndCellDescriptors();
+  SetRNNDescriptor(input.device());
+  UpdateSpaces(seq_length_, input.device());
 };
 
-void CudnnRNNHandle::UpdateStates(size_t num_x, const vector<Tensor> &inputs) {
-  UpdateIODescriptors(num_x, inputs);
-  size_t new_batch_size = inputs.at(0).shape(0);
-  if (batch_size_ != new_batch_size)
-    ResetHiddenAndCellDescriptors(new_batch_size);
-  if (rnn_desc_ == nullptr)
-    SetRNNDescriptor(inputs.at(0).device());
-  UpdateSpaces(num_x, inputs.at(0).device());
-  batch_size_ = new_batch_size;
-  seq_length_ = num_x;
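+// destroy every cuDNN descriptor the handle owns; each call is guarded by a
+// nullptr check so a partially constructed handle can still be torn down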
+CudnnRNNHandle::~CudnnRNNHandle() {
+  if (weight_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc_));
+  if (dropout_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_));
+  if (rnn_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc_));
+  if (hx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc_));
+  if (hy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc_));
+  if (cx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc_));
+  if (cy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc_));
+  if (dhx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhx_desc_));
+  if (dhy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhy_desc_));
+  if (dcx_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcx_desc_));
+  if (dcy_desc_ != nullptr)
+    CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcy_desc_));
+  DestroyIODescriptors();
 };
 
 void CudnnRNNHandle::DestroyIODescriptors() {
   if (x_descs_ != nullptr) {
-    for (size_t i = 0; i < max_length_; i++) {
+    for (size_t i = 0; i < seq_length_; i++) {
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i]));
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i]));
     }
     delete [] x_descs_;
     delete [] dx_descs_;
   }
   if (y_descs_ != nullptr) {
-    for (size_t i = 0; i < max_length_; i++) {
+    for (size_t i = 0; i < seq_length_; i++) {
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i]));
       CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i]));
     }
@@ -98,61 +96,60 @@ void CudnnRNNHandle::DestroyIODescriptors() {
   }
 };
 
-void CudnnRNNHandle::UpdateIODescriptors(size_t len, const vector<Tensor> &inputs) {
-  bool reset = false;
-  if (max_length_ < len) {
-    DestroyIODescriptors();
-    max_length_ = len;
-    x_descs_ = new cudnnTensorDescriptor_t[len];
-    dx_descs_ = new cudnnTensorDescriptor_t[len];
-    y_descs_ = new cudnnTensorDescriptor_t[len];
-    dy_descs_ = new cudnnTensorDescriptor_t[len];
-    for (size_t i = 0; i < len; i++) {
+
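+// allocate one cuDNN tensor descriptor per time step for x/dx/y/dy and fill
+// them from the dimensions captured at construction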
+void CudnnRNNHandle::UpdateIODescriptors(const Tensor &input) {
+  x_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+  dx_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+  y_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+  dy_descs_ = new cudnnTensorDescriptor_t[seq_length_];
+  for (size_t i = 0; i < seq_length_; i++) {
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i]));
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dx_descs_[i]));
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i]));
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dy_descs_[i]));
   }
-    reset = true;
-  }
 
-  for (size_t i = 0; i < len; i++) {
-    CHECK_EQ(inputs[i].shape(1), input_size_);
-    if (inputs[i].shape(0) != batch_size_ || reset) {
+  for (size_t i = 0; i < seq_length_; i++) {
+    CHECK_EQ(input.shape(2), input_size_);
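+    // each step is a fully packed (batch, feature, 1) tensor; the strides
+    // below encode that packing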
     int d[3] = {1, 1, 1}, s[3] = {1, 1, 1};
-    d[0] = static_cast<int>(inputs[i].shape(0));
+    d[0] = static_cast<int>(batch_size_);
     CHECK_GT(d[0], 0);
-    d[1] = static_cast<int>(inputs[i].shape(1));
+    d[1] = static_cast<int>(input_size_);
     s[0] = d[1] * d[2];
     s[1] = d[2];
     CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], dtype_, 3, d, s));
     CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s));
 
-    d[0] = static_cast<int>(inputs[i].shape(0));
+    d[0] = static_cast<int>(batch_size_);
     d[1] = static_cast<int>(hidden_size_ * num_directions_);
     s[0] = d[1] * d[2];
     s[1] = d[2];
     CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s));
     CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s));
   }
-  }
 };
 
-void CudnnRNNHandle::ResetHiddenAndCellDescriptors(size_t batch_size) {
-  if (batch_size_ == 0) {
+void CudnnRNNHandle::ResetHiddenAndCellDescriptors() {
+  if (cx_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_));
+  if (dcx_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_));
+  if (cy_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_));
+  if (dcy_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcy_desc_));
+  if (hx_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_));
+  if (dhx_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhx_desc_));
+  if (hy_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_));
+  if (dhy_desc_ == nullptr)
     CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhy_desc_));
-  }
 
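+  // hidden and cell states are (num_stacks * num_directions, batch, hidden),
+  // fully packed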
   int dim[3] = {1, 1, 1};
   dim[0] = static_cast<int>(num_stacks_ * num_directions_);
-  dim[1] = static_cast<int>(batch_size);
+  dim[1] = static_cast<int>(batch_size_);
   dim[2] = static_cast<int>(hidden_size_);
   int stride[3] = {1, 1, 1};
   stride[0] = dim[1] * dim[2];
@@ -229,7 +226,7 @@ void CudnnRNNHandle::UpdateSpaces(size_t seq_length, shared_ptr<Device> dev) {
     reserve_space_ = Tensor(Shape{count}, dev, kChar);
     // reserve_space_.SetValue(0);
   }
-}
+};
 
 Tensor MergeInputs(size_t num, const vector<Tensor> &in) {
   if (num == 1)
@@ -265,15 +262,14 @@ vector<Tensor> SplitOutput(size_t num, size_t dim,
 
 std::vector<Tensor> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
   DataType dtype = input.data_type();
-  auto dev = input.at(0).device();
+  auto dev = input.device();
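+  // output holds seq_length * batch * hidden * num_directions values, derived
+  // from input.Size() by swapping input_size_ for hidden_size_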
 
 
   Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
   Tensor output(outshape, dev, dtype);
   // LOG(INFO) << "output size " << output.Size();
 
   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-  CHECK_EQ(hx.shape(), state_shape);
   Tensor hy(state_shape, dev, dtype);
 
   Tensor cy;
@@ -339,7 +335,6 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tens
   // LOG(INFO) << "output size " << output.Size();
 
   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-  CHECK_EQ(hx.shape(), state_shape);
   Tensor hy(state_shape, dev, dtype);
 
   Tensor cy;
@@ -389,7 +384,7 @@ std::vector<Tensor> GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tens
   return {output, hy, cy};
 };
 
-std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tensor> &dY, const Tensor &dh, const Tensor &dc, const vector<Tensor> &cache) {
+std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const std::vector<Tensor> &cache) {
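+  // cache packs the tensors saved by the forward pass: x, y and hx here,
+  // plus the cx and W entries used further down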
   const Tensor x = cache[0];
   const Tensor y = cache[1];
   const Tensor hx = cache[2];
@@ -399,26 +394,22 @@ std::vector<Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector<Tenso
   auto dev = y.device();
   auto dtype = y.data_type();
 
-
   CHECK_EQ(dY.Size(), y.Size());
 
-
   Shape xshape{y.Size() * crh.input_size_ / crh.hidden_size_ / crh.num_directions_};
-  CHECK_EQ(x.shape(), xshape);
   Tensor dx(xshape, dev, dtype);
 
   Tensor dw(W.shape(), dev, dtype);
 
   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-  CHECK_EQ(hx.shape(), state_shape);
   Tensor dhx(state_shape, dev, dtype);
 
   Tensor dcx;
   if (crh.has_cell_)
     dcx.ResetLike(dhx);
 
   dw.SetValue(0.0f);
-  Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(),
+  Block *yb = y.block(), *dyb = dY.block(), *dhyb = dhy.block(),
         *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(),
         *wb = W.block(), *dwb = dw.block(), *hxb = hx.block(),
         *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(),