From 122e5b42c3e5b52bb78b0abe1c793f4da23eadce Mon Sep 17 00:00:00 2001 From: xuewanqi Date: Mon, 16 Jul 2018 19:30:04 +0000 Subject: [PATCH 1/5] SINGA-386 Implement RNN operation for autograd - develop origin version of some support funcitons for rnn operation --- src/model/operation/rnn.cc | 407 +++++++++++++++++++++++++++++++++++++ src/model/operation/rnn.h | 63 ++++++ 2 files changed, 470 insertions(+) create mode 100644 src/model/operation/rnn.cc create mode 100644 src/model/operation/rnn.h diff --git a/src/model/operation/rnn.cc b/src/model/operation/rnn.cc new file mode 100644 index 0000000000..2ef213912f --- /dev/null +++ b/src/model/operation/rnn.cc @@ -0,0 +1,407 @@ +RecHandle::RecHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional) { + + input_size_ = Input_size; + CHECK_GT(input_size_, 0u); + hidden_size_ = Hidden_size; + CHECK_GT(hidden_size_, 0u); + num_stacks_ = Num_stacks; + CHECK_GT(num_stacks_, 0u); + dropout_ = Dropout; // drop probability + CHECK_GE(dropout_, 0); + + if (bidirectional) + num_directions_ = 2; + else + num_directions_ = 1; + + rnn_mode_ = Rnn_mode; + if (rnn_mode_ == "lstm") { + has_cell_ = true; + } else if (rnn_mode_ !="relu" && rnn_mode_ != "tanh" && rnn_mode_ != "gru") { + LOG(FATAL) << "RNN memory unit (mode) of " << rnn_mode_ + << " is not supported Please use 'relu', 'tanh', 'lstm' and 'gru'"; + } + // the first constant (4) is the size of float + // the second constant (2, 8, 6) is the number of sets of params + int mult = 1; + if (rnn_mode_ == "relu" || rnn_mode_ == "tanh") + mult *= 1; + else if (rnn_mode_ == "lstm") + mult *= 4; + else if (rnn_mode_ == "gru") + mult *= 3; + if (direction_ == "bidirectional") + mult *= 2; + + weight_size = 0; + for (size_t i = 0; i < num_stacks_; i++) { + size_t dim = hidden_size_ * (in_sample[0] + hidden_size_ + 2); + if (i > 0) + dim = hidden_size_ * (hidden_size_ + hidden_size_ + 2); + weight_size += mult * dim; + } +} + + +CudnnRecHandle::CudnnRecHandle(const vector &inputs,const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string nonlinearity, const bool bias, const float dropout, const bool bidirectional): + RecHandle(Input_size, Hidden_size, Num_stacks, nonlinearity, bias, dropout, bidirectional){ + + CHECK_GT(inputs.size(), 1u + has_cell_); + size_t num_x = inputs.size() - has_cell_ - 1; + + UpdateStates(num_x, inputs); + } + +void CudnnRecHandle::UpdateStates(size_t num_x, const vector &inputs) { + UpdateIODescriptors(num_x, inputs); + size_t new_batch_size = inputs.at(0).shape(0); + if (batch_size_ != new_batch_size) + ResetHiddenAndCellDescriptors(new_batch_size); + if (rnn_desc_ == nullptr) + SetRNNDescriptor(inputs.at(0).device()); + UpdateSpaces(num_x, inputs.at(0).device()); + batch_size_ = new_batch_size; + seq_length_ = num_x; +} + +void CudnnRecHandle::UpdateIODescriptors(size_t len, const vector &inputs) { + bool reset = false; + if (max_length_ < len) { + DestroyIODescriptors(); + max_length_ = len; + x_descs_ = new cudnnTensorDescriptor_t[len]; + dx_descs_ = new cudnnTensorDescriptor_t[len]; + y_descs_ = new cudnnTensorDescriptor_t[len]; + dy_descs_ = new cudnnTensorDescriptor_t[len]; + for (size_t i = 0; i < len; i++) { + CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i])); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dx_descs_[i])); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i])); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dy_descs_[i])); 
+ } + reset = true; + } + + for (size_t i = 0; i < len; i++) { + CHECK_EQ(inputs[i].shape(1), input_size_); + if (inputs[i].shape(0) != batch_size_ || reset) { + int d[3] = {1, 1, 1}, s[3] = {1, 1, 1}; + d[0] = static_cast(inputs[i].shape(0)); + CHECK_GT(d[0], 0); + d[1] = static_cast(inputs[i].shape(1)); + s[0] = d[1] * d[2]; + s[1] = d[2]; + CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], dtype_, 3, d, s)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s)); + + d[0] = static_cast(inputs[i].shape(0)); + d[1] = static_cast(hidden_size_ * num_directions_); + s[0] = d[1] * d[2]; + s[1] = d[2]; + CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s)); + } + } +} + +void CudnnRecHandle::ResetHiddenAndCellDescriptors(size_t batch_size) { + if (batch_size_ == 0) { + CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_)); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_)); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_)); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcy_desc_)); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_)); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhx_desc_)); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_)); + CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhy_desc_)); + } + + int dim[3] = {1, 1, 1}; + dim[0] = static_cast(num_stacks_ * num_directions_); + dim[1] = static_cast(batch_size); + dim[2] = static_cast(hidden_size_); + int stride[3] = {1, 1, 1}; + stride[0] = dim[1] * dim[2]; + stride[1] = dim[2]; + CUDNN_CHECK(cudnnSetTensorNdDescriptor(hx_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dhx_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(hy_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dhy_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(cx_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcx_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(cy_desc_, dtype_, 3, dim, stride)); + CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcy_desc_, dtype_, 3, dim, stride)); +} + +void CudnnRecHandle::SetRNNDescriptor(shared_ptr dev) { + auto ctx = dev->context(0); + CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc_)); + size_t state_size; + CUDNN_CHECK(cudnnDropoutGetStatesSize(ctx->cudnn_handle, &state_size)); + dropout_state_ = Tensor(Shape{state_size}, dev, kChar); + CUDNN_CHECK(cudnnSetDropoutDescriptor( + dropout_desc_, ctx->cudnn_handle, 1 - dropout_, // keep probability + dropout_state_.block()->mutable_data(), state_size, seed_)); + + CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_)); + cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; + //if (input_mode_ == "skip") + //input_mode = CUDNN_SKIP_INPUT; + + cudnnDirectionMode_t direction = CUDNN_UNIDIRECTIONAL; + if (num_directions_ == 2) + direction = CUDNN_BIDIRECTIONAL; + + cudnnRNNMode_t rnn_mode = CUDNN_LSTM; + if (rnn_mode_ == "relu") + rnn_mode = CUDNN_RNN_RELU; + else if (rnn_mode_ == "tanh") + rnn_mode = CUDNN_RNN_TANH; + else if (rnn_mode_ == "gru") + rnn_mode = CUDNN_GRU; +#if CUDNN_MAJOR <= 5 + CUDNN_CHECK(cudnnSetRNNDescriptor(rnn_desc_, hidden_size_, num_stacks_, + dropout_desc_, input_mode, direction, + rnn_mode, dtype_)); +#else + CUDNN_CHECK(cudnnSetRNNDescriptor(ctx->cudnn_handle, rnn_desc_, hidden_size_, num_stacks_, + dropout_desc_, input_mode, direction, + rnn_mode, CUDNN_RNN_ALGO_STANDARD, dtype_)); +#endif + size_t 
weight_size_; + CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0], + &weight_size_, dtype_)); + // check the size manually calculated + CHECK_EQ(weight_size_, weight_size * sizeof(float)); + int filter_dim[3] = {static_cast(weight_size_), 1, 1}; + CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_)); + CUDNN_CHECK(cudnnSetFilterNdDescriptor(weight_desc_, dtype_, + CUDNN_TENSOR_NCHW, 3, filter_dim)); +} + +void CudnnRecHandle::UpdateSpaces(size_t seq_length, shared_ptr dev) { + size_t count; + auto ctx = dev->context(0); + CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnn_desc_, + seq_length, x_descs_, &count)); + if (workspace_.Size() != count) { + workspace_ = Tensor(Shape{count}, dev, kChar); + // workspace_.SetValue(0); + } + + CUDNN_CHECK(cudnnGetRNNTrainingReserveSize(ctx->cudnn_handle, rnn_desc_, + seq_length, x_descs_, &count)); + if (reserve_space_.Size() != count) { + reserve_space_ = Tensor(Shape{count}, dev, kChar); + // reserve_space_.SetValue(0); + } +} + +Tensor MergeInputs(size_t num, const vector &in) { + if (num == 1) + return in.at(0); + size_t size = 0; + for (size_t i = 0; i < num; i++) size += in.at(i).Size(); + Tensor out(Shape{size}, in.at(0).device(), in.at(0).data_type()); + for (size_t i = 0, offset = 0; i < num; i++) { + CopyDataToFrom(&out, in.at(i), in.at(i).Size(), offset); + offset += in.at(i).Size(); + } + return out; +} + +vector SplitOutput(size_t num, size_t dim, + const vector &in, + const Tensor output) { + vector outputs; + if (num == 1) { + outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim})); + } else { + for (size_t i = 0, offset = 0; offset < output.Size(); i++) { + Shape s{in.at(i).shape(0), dim}; + Tensor out(s, output.device(), output.data_type()); + CopyDataToFrom(&out, output, out.Size(), 0, offset); + outputs.push_back(out); + offset += out.Size(); + } + CHECK_EQ(num, outputs.size()); + } + return outputs; +} + +const std::vector> GpuRecForwardTraining(const CudnnRecHandle &crh, const vector &inputs, const Tensor &W){ + DataType dtype = inputs.at(0).data_type(); + auto dev = inputs.at(0).device(); + + CHECK_GT(inputs.size(), 1u + has_cell_); + size_t num_x = inputs.size() - has_cell_ - 1; + Tensor input = MergeInputs(num_x, inputs); + + if (rnn_desc_ != nullptr) + CHECK_EQ(dtype_, GetCudnnDataType(dtype)) + << "Cannot change cudnn data type during training from " << dtype_ + << " to " << GetCudnnDataType(dtype); + else + dtype_ = GetCudnnDataType(dtype); + + Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_}; + Tensor output(outshape, dev, dtype); + // LOG(INFO) << "output size " << output.Size(); + Tensor hx = inputs.at(num_x); + Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_}; + Tensor hy(state_shape, dev, dtype); + + Tensor cy, cx; + if (crh.has_cell_) { + cx = inputs.at(num_x + 1); + cy.ResetLike(hy); + } + + int did = input.device()->id(); + CHECK_EQ(did, output.device()->id()); + if (hx.Size()) { + CHECK_EQ(did, hx.device()->id()); + CHECK_EQ(hx.device()->lang(), kCuda); + } + if (cx.Size()) { + CHECK_EQ(did, cx.device()->id()); + CHECK_EQ(cx.device()->lang(), kCuda); + } + CHECK_EQ(did, W.device()->id()); + CHECK_EQ(did, crh.workspace_.device()->id()); + CHECK_EQ(input.device()->lang(), kCuda); + CHECK_EQ(output.device()->lang(), kCuda); + CHECK_EQ(W.device()->lang(), kCuda); + CHECK_EQ(crh.workspace_.device()->lang(), kCuda); + CHECK_EQ(crh.reserve_space_.device()->lang(), kCuda); + CHECK_EQ(did, 
crh.reserve_space_.device()->id()); + + Block *inb = input.block(), *outb = output.block(), + *wb = W.block(), *hxb = hx.block(), *cxb = cx.block(), + *hyb = hy.block(), *cyb = cy.block(), + *wspace = crh.workspace_.block(), + *rspace = crh.reserve_space_.block(); + + dev->Exec( + [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, &crh](Context * ctx) { + // clang-format off + cudnnRNNForwardTraining( + ctx->cudnn_handle, + crh.rnn_desc_, + crh.seq_length_, + crh.x_descs_, inb->data(), + crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(), + crh.weight_desc_, wb->data(), + crh.y_descs_, outb->mutable_data(), + crh.hy_desc_, hyb->mutable_data(), + crh.cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(), + wspace->mutable_data(), + crh.workspace_.Size(), rspace->mutable_data(), + crh.reserve_space_.Size()); + // clang-format on + }, + {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace}); + + auto outputs = + SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output); + outputs.push_back(hy); + if (has_cell_) outputs.push_back(cy); + + std::vector cache; + cache.push_back(input); + cache.push_back(output); + cache.push_back(hx); + cache.push_back(cx); + cache.push_back(W); + + return {outputs, cache}; +} + +const std::vector GpuRecForwardInference(const CudnnRecHandle &crh, const vector &inputs, const Tensor &W){ + +} + +const std::pair, Tensor> GpuRecBackward(const CudnnRecHandle &crh, const vector &grads, const vector &cache){ + const Tensor x= cache[0]; + const Tensor y= cache[1]; + const Tensor hx= cache[2]; + const Tensor cx= cache[3]; + const Tensor W= cache[4]; + + auto dev = y.device(); + auto dtype = y.data_type(); + + CHECK_GT(grads.size(), 1u + crh.has_cell_); + size_t num_dy = grads.size() - crh.has_cell_ - 1; + CHECK_EQ(num_dy, crh.seq_length_); + const Tensor dy = MergeInputs(num_dy, grads); + CHECK_EQ(dy.Size(), y.Size()); + const Tensor dhy = grads.at(num_dy); + Tensor dcy; + if (crh.has_cell_) + dcy = grads.at(num_dy + 1); + + Shape xshape{y.Size() * crh.input_size_ / crh.hidden_size_ / crh.num_directions_}; + Tensor dx(xshape, dev, dtype); + Tensor dw(W.shape(), dev, dtype); + Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_}; + Tensor dhx(state_shape, dev, dtype); + Tensor dcx; + if (crh.has_cell_) + dcx.ResetLike(dhx); + dw.SetValue(0.0f); + Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(), + *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(), + *wb = W.block(), *dwb = dw.block(), *hxb = hx.block(), + *dxb = dx.block(), *dhxb = dhx.block(), *dcxb = dcx.block(), + *wspace = crh.workspace_.block(), *rspace = crh.reserve_space_.block(); + + y.device()->Exec( + [yb, dyb, dhyb, dcyb, xb, cxb, wb, dwb, hxb, dxb, dhxb, dcxb, wspace, + rspace, &crh](Context * ctx) { + // clang-format off + cudnnRNNBackwardData( + ctx->cudnn_handle, + crh.rnn_desc_, + crh.seq_length_, + crh.y_descs_, yb->data(), + crh.dy_descs_, dyb->data(), + crh.dhy_desc_, dhyb == nullptr ? nullptr : dhyb->data(), + crh.dcy_desc_, dcyb == nullptr ? nullptr : dcyb->data(), + crh.weight_desc_, wb->data(), + crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(), + crh.dx_descs_, dxb->mutable_data(), + crh.dhx_desc_, dhxb->mutable_data(), + crh.dcx_desc_, dcxb == nullptr ? 
nullptr : dcxb->mutable_data(), + wspace->mutable_data(), crh.workspace_.Size(), + rspace->mutable_data(), crh.reserve_space_.Size()); + cudnnRNNBackwardWeights( + ctx->cudnn_handle, + crh.rnn_desc_, + crh.seq_length_, + crh.x_descs_, xb->data(), + crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + crh.y_descs_, yb->data(), + wspace->data(), crh.workspace_.Size(), + crh.dweight_desc_, dwb->mutable_data(), + rspace->data(), crh.reserve_space_.Size()); + // clang-format on + }, + {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace}, + {dxb, dwb, dhxb, dcxb, wspace, rspace}); + + auto data_grads = SplitOutput(num_dy, crh.input_size_, grads, dx); + data_grads.push_back(dhx); + if (has_cell_) + data_grads.push_back(dcx); + + return std::make_pair(data_grads, dw); +} + + + + diff --git a/src/model/operation/rnn.h b/src/model/operation/rnn.h new file mode 100644 index 0000000000..5ca5c2151c --- /dev/null +++ b/src/model/operation/rnn.h @@ -0,0 +1,63 @@ +class RecHandle { +public: + RecHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string nonlinearity, const bool bias, const float dropout, const bool bidirectional); + + size_t input_size_; + size_t hidden_size_; + size_t num_stacks_; + float dropout_; + size_t seed_ = 0x1234567; + size_t num_directions_; + std::string rnn_mode_; + bool has_cell; + size_t weight_size; + + size_t batch_size = 0; + size_t seq_length_ = 0; + size_t max_length_ = 0; +} + +class CudnnRecHandle: public RecHandle { +public: + CudnnRecHandle(const vector &inputs,const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string nonlinearity, const bool bias, const float dropout, const bool bidirectional); + void UpdateStates(size_t num_x, const vector &inputs); + void UpdateIODescriptors(size_t len, const vector &inputs); + void ResetHiddenAndCellDescriptors(size_t batch_size); + void SetRNNDescriptor(shared_ptr dev); + void UpdateSpaces(size_t seq_length, shared_ptr dev); + + cudnnTensorDescriptor_t* x_descs_ = nullptr; + cudnnTensorDescriptor_t* dx_descs_ = nullptr; + cudnnTensorDescriptor_t* y_descs_ = nullptr; + cudnnTensorDescriptor_t* dy_descs_ = nullptr; + cudnnTensorDescriptor_t hx_desc_ = nullptr; + cudnnTensorDescriptor_t dhx_desc_ = nullptr; + cudnnTensorDescriptor_t cx_desc_ = nullptr; + cudnnTensorDescriptor_t dcx_desc_ = nullptr; + cudnnTensorDescriptor_t hy_desc_ = nullptr; + cudnnTensorDescriptor_t dhy_desc_ = nullptr; + cudnnTensorDescriptor_t cy_desc_ = nullptr; + cudnnTensorDescriptor_t dcy_desc_ = nullptr; + cudnnFilterDescriptor_t weight_desc_ = nullptr; + cudnnFilterDescriptor_t dweight_desc_ = nullptr; + cudnnRNNDescriptor_t rnn_desc_ = nullptr; + cudnnDropoutDescriptor_t dropout_desc_ = nullptr; + cudnnDataType_t dtype_ = CUDNN_DATA_FLOAT; + Tensor workspace_; + Tensor reserve_space_; + Tensor dropout_state_; +} + +Tensor MergeInputs(size_t num, const vector &in); + +vector SplitOutput(size_t num, size_t dim, + const vector &in, + const Tensor output); + +std::vector> GpuRecForwardTraining(const CudnnRecHandle &crh, const vector &inputs, const Tensor &W); + +std::vector GpuRecForwardInference(const CudnnRecHandle &crh, const vector &inputs,const Tensor &W); + +const std::pair, Tensor> GpuRecBackward(const CudnnRecHandle &crh, const vector &grads, const vector &cache); From 999b19cfe2a25e5c29fa67fe43fe85638d42677e Mon Sep 17 00:00:00 2001 From: xuewanqi Date: Tue, 17 Jul 2018 15:50:16 +0000 Subject: [PATCH 2/5] SINGA-386 Implement RNN operation for autograd - fix bugs in cpp 
parts, the codes can be made without error. --- python/singa/autograd.py | 40 ++++++- src/model/operation/rnn.cc | 227 ++++++++++++++++++++++++++----------- src/model/operation/rnn.h | 55 ++++++--- 3 files changed, 238 insertions(+), 84 deletions(-) diff --git a/python/singa/autograd.py b/python/singa/autograd.py index 56b5498a35..1e3c1b8cbd 100755 --- a/python/singa/autograd.py +++ b/python/singa/autograd.py @@ -798,7 +798,7 @@ def __call__(self, x): self.handle.device_id = x.device.id() y = batchnorm_2d(self.handle, x, self.scale, self.bias, - self.running_mean, self.running_var) + self.running_mean, self.running_var) return y @@ -962,3 +962,41 @@ def __init__(self, kernel_size, stride=None, padding=0): stride = kernel_size super(MaxPool2d, self).__init__( (1, kernel_size), (0, stride), (0, padding), False) + + +class _RNN(Operation): + + def __init__(self, handle): + self.handle = handle + + def forward(self, X, W): + + if self.handle.device_id == -1: + raise NotImplementedError + else: + if training: + out, self.cache = singa.GpuRNNForwardTraining( + self.handle, X, W) + else: + out = singa.GpuRNNForwardInference(self.handle, X, W) + return out + + def backward(self, dY): + assert training is True and hasattr( + self, 'cache'), 'Please set training as True before do BP. ' + + if dY.device().id() != self.handle.device_id: + dY.ToDevice(self.inputs[0].device()) + + if self.handle.device_id == -1: + raise NotImplementedError + else: + dX, dW = singa.GpuRNNBackward(self.handle, dY, self.cache) + return dX, dW + + +def rnn(): + pass + + +class RNN(Layer): diff --git a/src/model/operation/rnn.cc b/src/model/operation/rnn.cc index 2ef213912f..afeba67686 100644 --- a/src/model/operation/rnn.cc +++ b/src/model/operation/rnn.cc @@ -1,6 +1,10 @@ -RecHandle::RecHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, - const std::string Rnn_mode, const float Dropout, const bool bidirectional) { - +#include "./rnn.h" + +namespace singa { + +RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional) { + input_size_ = Input_size; CHECK_GT(input_size_, 0u); hidden_size_ = Hidden_size; @@ -18,9 +22,9 @@ RecHandle::RecHandle(const size_t Input_size, const size_t Hidden_size, const si rnn_mode_ = Rnn_mode; if (rnn_mode_ == "lstm") { has_cell_ = true; - } else if (rnn_mode_ !="relu" && rnn_mode_ != "tanh" && rnn_mode_ != "gru") { + } else if (rnn_mode_ != "relu" && rnn_mode_ != "tanh" && rnn_mode_ != "gru") { LOG(FATAL) << "RNN memory unit (mode) of " << rnn_mode_ - << " is not supported Please use 'relu', 'tanh', 'lstm' and 'gru'"; + << " is not supported Please use 'relu', 'tanh', 'lstm' and 'gru'"; } // the first constant (4) is the size of float // the second constant (2, 8, 6) is the number of sets of params @@ -31,30 +35,39 @@ RecHandle::RecHandle(const size_t Input_size, const size_t Hidden_size, const si mult *= 4; else if (rnn_mode_ == "gru") mult *= 3; - if (direction_ == "bidirectional") + if (bidirectional) mult *= 2; weight_size = 0; for (size_t i = 0; i < num_stacks_; i++) { - size_t dim = hidden_size_ * (in_sample[0] + hidden_size_ + 2); + size_t dim = hidden_size_ * (input_size_ + hidden_size_ + 2); if (i > 0) dim = hidden_size_ * (hidden_size_ + hidden_size_ + 2); weight_size += mult * dim; } -} +}; +#ifdef USE_CUDNN -CudnnRecHandle::CudnnRecHandle(const vector &inputs,const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, - 
const std::string nonlinearity, const bool bias, const float dropout, const bool bidirectional): - RecHandle(Input_size, Hidden_size, Num_stacks, nonlinearity, bias, dropout, bidirectional){ +CudnnRNNHandle::CudnnRNNHandle(const vector &inputs, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional): + RNNHandle(Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional) { CHECK_GT(inputs.size(), 1u + has_cell_); size_t num_x = inputs.size() - has_cell_ - 1; + DataType dtype = inputs.at(0).data_type(); + if (rnn_desc_ != nullptr) + CHECK_EQ(dtype_, GetCudnnDataType(dtype)) + << "Cannot change cudnn data type during training from " << dtype_ + << " to " << GetCudnnDataType(dtype); + else + dtype_ = GetCudnnDataType(dtype); + UpdateStates(num_x, inputs); - } +}; -void CudnnRecHandle::UpdateStates(size_t num_x, const vector &inputs) { +void CudnnRNNHandle::UpdateStates(size_t num_x, const vector &inputs) { UpdateIODescriptors(num_x, inputs); size_t new_batch_size = inputs.at(0).shape(0); if (batch_size_ != new_batch_size) @@ -64,9 +77,28 @@ void CudnnRecHandle::UpdateStates(size_t num_x, const vector &inputs) { UpdateSpaces(num_x, inputs.at(0).device()); batch_size_ = new_batch_size; seq_length_ = num_x; -} +}; + +void CudnnRNNHandle::DestroyIODescriptors() { + if (x_descs_ != nullptr) { + for (size_t i = 0; i < max_length_; i++) { + CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i])); + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i])); + } + delete [] x_descs_; + delete [] dx_descs_; + } + if (y_descs_ != nullptr) { + for (size_t i = 0; i < max_length_; i++) { + CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i])); + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i])); + } + delete [] y_descs_; + delete [] dy_descs_; + } +}; -void CudnnRecHandle::UpdateIODescriptors(size_t len, const vector &inputs) { +void CudnnRNNHandle::UpdateIODescriptors(size_t len, const vector &inputs) { bool reset = false; if (max_length_ < len) { DestroyIODescriptors(); @@ -104,9 +136,9 @@ void CudnnRecHandle::UpdateIODescriptors(size_t len, const vector &input CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s)); } } -} +}; -void CudnnRecHandle::ResetHiddenAndCellDescriptors(size_t batch_size) { +void CudnnRNNHandle::ResetHiddenAndCellDescriptors(size_t batch_size) { if (batch_size_ == 0) { CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_)); CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_)); @@ -133,9 +165,9 @@ void CudnnRecHandle::ResetHiddenAndCellDescriptors(size_t batch_size) { CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcx_desc_, dtype_, 3, dim, stride)); CUDNN_CHECK(cudnnSetTensorNdDescriptor(cy_desc_, dtype_, 3, dim, stride)); CUDNN_CHECK(cudnnSetTensorNdDescriptor(dcy_desc_, dtype_, 3, dim, stride)); -} +}; -void CudnnRecHandle::SetRNNDescriptor(shared_ptr dev) { +void CudnnRNNHandle::SetRNNDescriptor(shared_ptr dev) { auto ctx = dev->context(0); CUDNN_CHECK(cudnnCreateDropoutDescriptor(&dropout_desc_)); size_t state_size; @@ -148,7 +180,7 @@ void CudnnRecHandle::SetRNNDescriptor(shared_ptr dev) { CUDNN_CHECK(cudnnCreateRNNDescriptor(&rnn_desc_)); cudnnRNNInputMode_t input_mode = CUDNN_LINEAR_INPUT; //if (input_mode_ == "skip") - //input_mode = CUDNN_SKIP_INPUT; + //input_mode = CUDNN_SKIP_INPUT; cudnnDirectionMode_t direction = CUDNN_UNIDIRECTIONAL; if (num_directions_ == 2) @@ -179,9 +211,9 @@ void CudnnRecHandle::SetRNNDescriptor(shared_ptr dev) { 
CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_)); CUDNN_CHECK(cudnnSetFilterNdDescriptor(weight_desc_, dtype_, CUDNN_TENSOR_NCHW, 3, filter_dim)); -} +}; -void CudnnRecHandle::UpdateSpaces(size_t seq_length, shared_ptr dev) { +void CudnnRNNHandle::UpdateSpaces(size_t seq_length, shared_ptr dev) { size_t count; auto ctx = dev->context(0); CUDNN_CHECK(cudnnGetRNNWorkspaceSize(ctx->cudnn_handle, rnn_desc_, @@ -210,11 +242,11 @@ Tensor MergeInputs(size_t num, const vector &in) { offset += in.at(i).Size(); } return out; -} +}; vector SplitOutput(size_t num, size_t dim, - const vector &in, - const Tensor output) { + const vector &in, + const Tensor output) { vector outputs; if (num == 1) { outputs.push_back(Reshape(output, Shape{in.at(0).shape(0), dim})); @@ -229,30 +261,23 @@ vector SplitOutput(size_t num, size_t dim, CHECK_EQ(num, outputs.size()); } return outputs; -} +}; -const std::vector> GpuRecForwardTraining(const CudnnRecHandle &crh, const vector &inputs, const Tensor &W){ +std::vector> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector &inputs, const Tensor &W) { DataType dtype = inputs.at(0).data_type(); auto dev = inputs.at(0).device(); - CHECK_GT(inputs.size(), 1u + has_cell_); - size_t num_x = inputs.size() - has_cell_ - 1; + CHECK_GT(inputs.size(), 1u + crh.has_cell_); + size_t num_x = inputs.size() - crh.has_cell_ - 1; Tensor input = MergeInputs(num_x, inputs); - if (rnn_desc_ != nullptr) - CHECK_EQ(dtype_, GetCudnnDataType(dtype)) - << "Cannot change cudnn data type during training from " << dtype_ - << " to " << GetCudnnDataType(dtype); - else - dtype_ = GetCudnnDataType(dtype); - Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_}; Tensor output(outshape, dev, dtype); // LOG(INFO) << "output size " << output.Size(); Tensor hx = inputs.at(num_x); Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_}; Tensor hy(state_shape, dev, dtype); - + Tensor cy, cx; if (crh.has_cell_) { cx = inputs.at(num_x + 1); @@ -285,30 +310,30 @@ const std::vector> GpuRecForwardTraining(const CudnnRecHandl *rspace = crh.reserve_space_.block(); dev->Exec( - [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, &crh](Context * ctx) { - // clang-format off - cudnnRNNForwardTraining( - ctx->cudnn_handle, - crh.rnn_desc_, - crh.seq_length_, - crh.x_descs_, inb->data(), - crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(), - crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(), - crh.weight_desc_, wb->data(), - crh.y_descs_, outb->mutable_data(), - crh.hy_desc_, hyb->mutable_data(), - crh.cy_desc_, cyb == nullptr ? nullptr : cyb->mutable_data(), - wspace->mutable_data(), - crh.workspace_.Size(), rspace->mutable_data(), - crh.reserve_space_.Size()); - // clang-format on - }, - {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace}); + [inb, outb, wb, hxb, cxb, hyb, cyb, wspace, rspace, &crh](Context * ctx) { + // clang-format off + cudnnRNNForwardTraining( + ctx->cudnn_handle, + crh.rnn_desc_, + crh.seq_length_, + crh.x_descs_, inb->data(), + crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(), + crh.weight_desc_, wb->data(), + crh.y_descs_, outb->mutable_data(), + crh.hy_desc_, hyb->mutable_data(), + crh.cy_desc_, cyb == nullptr ? 
nullptr : cyb->mutable_data(), + wspace->mutable_data(), + crh.workspace_.Size(), rspace->mutable_data(), + crh.reserve_space_.Size()); + // clang-format on + }, + {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace}); auto outputs = SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output); outputs.push_back(hy); - if (has_cell_) outputs.push_back(cy); + if (crh.has_cell_) outputs.push_back(cy); std::vector cache; cache.push_back(input); @@ -318,18 +343,82 @@ const std::vector> GpuRecForwardTraining(const CudnnRecHandl cache.push_back(W); return {outputs, cache}; -} +}; + +std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const vector &inputs, const Tensor &W) { + DataType dtype = inputs.at(0).data_type(); + auto dev = inputs.at(0).device(); -const std::vector GpuRecForwardInference(const CudnnRecHandle &crh, const vector &inputs, const Tensor &W){ + CHECK_GT(inputs.size(), 1u + crh.has_cell_); + size_t num_x = inputs.size() - crh.has_cell_ - 1; + Tensor input = MergeInputs(num_x, inputs); -} + Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_}; + Tensor output(outshape, dev, dtype); + // LOG(INFO) << "output size " << output.Size(); + Tensor hx = inputs.at(num_x); + Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_}; + Tensor hy(state_shape, dev, dtype); + + Tensor cy, cx; + if (crh.has_cell_) { + cx = inputs.at(num_x + 1); + cy.ResetLike(hy); + } + + int did = input.device()->id(); + CHECK_EQ(did, output.device()->id()); + if (hx.Size()) { + CHECK_EQ(did, hx.device()->id()); + CHECK_EQ(hx.device()->lang(), kCuda); + } + if (cx.Size()) { + CHECK_EQ(did, cx.device()->id()); + CHECK_EQ(cx.device()->lang(), kCuda); + } + CHECK_EQ(did, W.device()->id()); + CHECK_EQ(did, crh.workspace_.device()->id()); + CHECK_EQ(input.device()->lang(), kCuda); + CHECK_EQ(output.device()->lang(), kCuda); + CHECK_EQ(W.device()->lang(), kCuda); + CHECK_EQ(crh.workspace_.device()->lang(), kCuda); -const std::pair, Tensor> GpuRecBackward(const CudnnRecHandle &crh, const vector &grads, const vector &cache){ - const Tensor x= cache[0]; - const Tensor y= cache[1]; - const Tensor hx= cache[2]; - const Tensor cx= cache[3]; - const Tensor W= cache[4]; + Block *inb = input.block(), *outb = output.block(), + *wb = W.block(), *hxb = hx.block(), *cxb = cx.block(), + *hyb = hy.block(), *cyb = cy.block(), + *wspace = crh.workspace_.block(); + + dev->Exec([inb, outb, wb, hxb, cxb, hyb, cyb, wspace, &crh](Context * ctx) { + // clang-format off + cudnnRNNForwardInference( + ctx->cudnn_handle, + crh.rnn_desc_, + crh.seq_length_, + crh.x_descs_, inb->data(), + crh.hx_desc_, hxb == nullptr ? nullptr : hxb->data(), + crh.cx_desc_, cxb == nullptr ? nullptr : cxb->data(), + crh.weight_desc_, wb->data(), + crh.y_descs_, outb->mutable_data(), + crh.hy_desc_, hyb->mutable_data(), + crh.cy_desc_, cyb == nullptr ? 
nullptr : cyb->mutable_data(), + wspace->mutable_data(), crh.workspace_.Size()); + // clang-format on + }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace}); + + auto outputs = + SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output); + outputs.push_back(hy); + if (crh.has_cell_) outputs.push_back(cy); + + return outputs; +}; + +std::pair, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector &grads, const vector &cache) { + const Tensor x = cache[0]; + const Tensor y = cache[1]; + const Tensor hx = cache[2]; + const Tensor cx = cache[3]; + const Tensor W = cache[4]; auto dev = y.device(); auto dtype = y.data_type(); @@ -396,12 +485,14 @@ const std::pair, Tensor> GpuRecBackward(const CudnnRecHandle &crh auto data_grads = SplitOutput(num_dy, crh.input_size_, grads, dx); data_grads.push_back(dhx); - if (has_cell_) + if (crh.has_cell_) data_grads.push_back(dcx); return std::make_pair(data_grads, dw); -} +}; +#endif // USE_CUDNN +} // namespace singa diff --git a/src/model/operation/rnn.h b/src/model/operation/rnn.h index 5ca5c2151c..0dbbac9974 100644 --- a/src/model/operation/rnn.h +++ b/src/model/operation/rnn.h @@ -1,7 +1,23 @@ -class RecHandle { +#ifndef SINGA_MODEL_OPERATION_CUDNN_RNN_H_ +#define SINGA_MODEL_OPERATION_CUDNN_RNN_H_ + +#include +#include +#include "singa/core/tensor.h" + + +#ifdef USE_CUDNN +#include +#include "../layer/cudnn_utils.h" +#endif // USE_CUDNN + + +namespace singa { + +class RNNHandle { public: - RecHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, - const std::string nonlinearity, const bool bias, const float dropout, const bool bidirectional); + RNNHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional); size_t input_size_; size_t hidden_size_; @@ -10,24 +26,28 @@ class RecHandle { size_t seed_ = 0x1234567; size_t num_directions_; std::string rnn_mode_; - bool has_cell; + bool has_cell_; size_t weight_size; - size_t batch_size = 0; + size_t batch_size_ = 0; size_t seq_length_ = 0; size_t max_length_ = 0; -} +}; + +#ifdef USE_CUDNN -class CudnnRecHandle: public RecHandle { +class CudnnRNNHandle: public RNNHandle { public: - CudnnRecHandle(const vector &inputs,const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, - const std::string nonlinearity, const bool bias, const float dropout, const bool bidirectional); + CudnnRNNHandle(const vector &inputs, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional); void UpdateStates(size_t num_x, const vector &inputs); + void DestroyIODescriptors(); void UpdateIODescriptors(size_t len, const vector &inputs); void ResetHiddenAndCellDescriptors(size_t batch_size); void SetRNNDescriptor(shared_ptr dev); void UpdateSpaces(size_t seq_length, shared_ptr dev); + cudnnTensorDescriptor_t* x_descs_ = nullptr; cudnnTensorDescriptor_t* dx_descs_ = nullptr; cudnnTensorDescriptor_t* y_descs_ = nullptr; @@ -48,16 +68,21 @@ class CudnnRecHandle: public RecHandle { Tensor workspace_; Tensor reserve_space_; Tensor dropout_state_; -} +}; Tensor MergeInputs(size_t num, const vector &in); vector SplitOutput(size_t num, size_t dim, - const vector &in, - const Tensor output); + const vector &in, + const Tensor output); + +std::vector> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector &inputs, const Tensor &W); + +std::vector GpuRNNForwardInference(const 
CudnnRNNHandle &crh, const vector &inputs, const Tensor &W); -std::vector> GpuRecForwardTraining(const CudnnRecHandle &crh, const vector &inputs, const Tensor &W); +std::pair, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector &grads, const vector &cache); -std::vector GpuRecForwardInference(const CudnnRecHandle &crh, const vector &inputs,const Tensor &W); +#endif // USE_CUDNN -const std::pair, Tensor> GpuRecBackward(const CudnnRecHandle &crh, const vector &grads, const vector &cache); +} // namespace singa +#endif // SINGA_MODEL_OPERATION_CUDNN_RNN_H_ From e3fba05ff79a1cb608ed7a10702760e956f42f87 Mon Sep 17 00:00:00 2001 From: xuewanqi Date: Thu, 19 Jul 2018 15:48:09 +0000 Subject: [PATCH 3/5] SINGA-386 Implement RNN operation for autograd - redesign some APIs to adapt to autograd --- python/singa/autograd.py | 125 +++++++++++++++++++++++++++++++++---- src/model/operation/rnn.cc | 79 ++++++++--------------- src/model/operation/rnn.h | 7 +-- 3 files changed, 141 insertions(+), 70 deletions(-) mode change 100644 => 100755 src/model/operation/rnn.cc mode change 100644 => 100755 src/model/operation/rnn.h diff --git a/python/singa/autograd.py b/python/singa/autograd.py index 1e3c1b8cbd..2365f16c44 100755 --- a/python/singa/autograd.py +++ b/python/singa/autograd.py @@ -969,34 +969,135 @@ class _RNN(Operation): def __init__(self, handle): self.handle = handle - def forward(self, X, W): + def forward(self, X, h0, c0, W): + # X of shape (seq_len, batch, input_size) + # h0_c0: (h0, c0) if lstm, else (h0,) + # h0, c0 of shape (num_layers * num_directions, batch, hidden_size) + if c0 is None: + assert self.rnn_mode != 'lstm' + c0= CTensor([]) # CTensor([]) and Tensor cx are the same? if self.handle.device_id == -1: raise NotImplementedError else: if training: - out, self.cache = singa.GpuRNNForwardTraining( - self.handle, X, W) + Y, hout, cout = singa.GpuRNNForwardTraining( + self.handle, X, h0, c0, W) + self.cache=(X, Y, h0, c0, W) else: - out = singa.GpuRNNForwardInference(self.handle, X, W) - return out + Y, hout, cout = singa.GpuRNNForwardInference( + self.handle, X, h0, c0, W) + + # Y of shape (seq_len, batch, hidden_size * num_directions) + # hout_cout: (hout, cout) if lstm, else (hout,) + # hout, cout of shape (num_layers * num_directions, batch, + # hidden_size) + oututs= 1dTo3d(Y) + + if self.rnn_mode != 'lstm': + return outputs, hout + else: + return outputs, hout, cout - def backward(self, dY): + def backward(self, dY, dh, dc=CTensor([])): assert training is True and hasattr( self, 'cache'), 'Please set training as True before do BP. 
' - if dY.device().id() != self.handle.device_id: - dY.ToDevice(self.inputs[0].device()) + dY_1d= 3dTo1d(dY) + + if dY_1d.device().id() != self.handle.device_id: + dY_1d.ToDevice(self.cache[0].device()) if self.handle.device_id == -1: raise NotImplementedError else: - dX, dW = singa.GpuRNNBackward(self.handle, dY, self.cache) - return dX, dW + dX_1d, dhout, dcout, dW = singa.GpuRNNBackward( + self.handle, dY_1d, dh, dc, self.cache) + dX = 1dTo3d(dX_1d) -def rnn(): - pass + if self.rnn_mode != 'lstm': + return dX, dhout, dW + else: + return dX, dhout, dcout, dW + + +def rnn(handle, x, h0, c0, W): + return _RNN(handle)(x, h0, c0, W) class RNN(Layer): + + def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False, rnn_mode='tanh'): + self.input_size = input_size + self.hidden_size = hidden_size + self.num_layers = num_layers + self.bias = bias + self.batch_first = batch_first + self.dropout = dropout + self.bidirectional = bidirectional + self.rnn_mode = rnn_mode + + if bias is not True or batch_first is not False: + raise NotImplementedError + + mult = 1 + if self.rnn_mode == 'tanh' or self.rnn_mode == 'relu': + mult *= 1 + elif self.rnn_mode == 'lstm': + mult *= 4 + elif self.rnn_mode == 'gru': + mult *= 3 + else: + raise ValueError + + if self.bidirectional: + mult *= 2 + + for k in range(num_layers): + if k == 1: + w_size = self.hidden_size * \ + (self.input_size + self.hidden_size + 2) + else: + w_size = self.hidden_size * \ + (self.hidden_size + self.hidden_size + 2) + W_Size *= mult * w_size + + self.W_Size = W_Size + self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True) + self.W.uniform(0.0, 1.0) + + def __call__(self, inputs, h0, c0=None): + # inputs of shape (seq_len, batch, input_size) + # h0_c0: (h0, c0) if lstm, else (h0,) + # h0, c0 of shape (num_layers * num_directions, batch, hidden_size) + + self.device_check(inputs, h0, self.W) + + if self.rnn_mode == 'lstm': + assert c0 is not None, 'Please input c0.' 
+ self.device_check(h0, c0) + + self.handle = signa.CudnnRNNHandle(inputs.data, *SOME_PARAMETERS*) + self.handle.device_id = inputs.device.id() + + X= 3dTo1d(inputs) + outputs = rnn(self.handle, X, h0, c0, self.W) + return outputs + +def 3dTo1d(self, inputs): + pass + +def 1dTo3d(self, *args): + pass + +class LSTM(RNN): + + def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False): + super(LSTM, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional,rnn_mode='lstm') + + +class GRU(RNN): + + def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False): + super(GRU, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional,rnn_mode='gru') diff --git a/src/model/operation/rnn.cc b/src/model/operation/rnn.cc old mode 100644 new mode 100755 index afeba67686..79c397d196 --- a/src/model/operation/rnn.cc +++ b/src/model/operation/rnn.cc @@ -263,24 +263,21 @@ vector SplitOutput(size_t num, size_t dim, return outputs; }; -std::vector> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector &inputs, const Tensor &W) { - DataType dtype = inputs.at(0).data_type(); - auto dev = inputs.at(0).device(); +std::vector GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) { + DataType dtype = input.data_type(); + auto dev = input.at(0).device(); - CHECK_GT(inputs.size(), 1u + crh.has_cell_); - size_t num_x = inputs.size() - crh.has_cell_ - 1; - Tensor input = MergeInputs(num_x, inputs); Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_}; Tensor output(outshape, dev, dtype); // LOG(INFO) << "output size " << output.Size(); - Tensor hx = inputs.at(num_x); + Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_}; + CHECK_EQ(hx.shape(), state_shape); Tensor hy(state_shape, dev, dtype); - Tensor cy, cx; + Tensor cy; if (crh.has_cell_) { - cx = inputs.at(num_x + 1); cy.ResetLike(hy); } @@ -330,39 +327,23 @@ std::vector> GpuRNNForwardTraining(const CudnnRNNHandle &crh }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace, rspace}); - auto outputs = - SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output); - outputs.push_back(hy); - if (crh.has_cell_) outputs.push_back(cy); - - std::vector cache; - cache.push_back(input); - cache.push_back(output); - cache.push_back(hx); - cache.push_back(cx); - cache.push_back(W); - - return {outputs, cache}; + return {output, hy, cy}; }; -std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const vector &inputs, const Tensor &W) { - DataType dtype = inputs.at(0).data_type(); - auto dev = inputs.at(0).device(); - - CHECK_GT(inputs.size(), 1u + crh.has_cell_); - size_t num_x = inputs.size() - crh.has_cell_ - 1; - Tensor input = MergeInputs(num_x, inputs); +std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) { + DataType dtype = input.data_type(); + auto dev = input.device(); Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_}; Tensor output(outshape, dev, dtype); // LOG(INFO) << "output size " << output.Size(); - Tensor hx = inputs.at(num_x); + Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_}; + CHECK_EQ(hx.shape(), state_shape); Tensor hy(state_shape, dev, 
dtype); - Tensor cy, cx; + Tensor cy; if (crh.has_cell_) { - cx = inputs.at(num_x + 1); cy.ResetLike(hy); } @@ -405,15 +386,10 @@ std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const vect // clang-format on }, {inb, wb, hxb, cxb}, {outb, hyb, cyb, wspace}); - auto outputs = - SplitOutput(num_x, crh.hidden_size_ * crh.num_directions_, inputs, output); - outputs.push_back(hy); - if (crh.has_cell_) outputs.push_back(cy); - - return outputs; + return {output, hy, cy}; }; -std::pair, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector &grads, const vector &cache) { +std::vector GpuRNNBackward(const CudnnRNNHandle &crh, const vector &dY, const Tensor &dh, const Tensor &dc, const vector &cache) { const Tensor x = cache[0]; const Tensor y = cache[1]; const Tensor hx = cache[2]; @@ -423,24 +399,24 @@ std::pair, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, cons auto dev = y.device(); auto dtype = y.data_type(); - CHECK_GT(grads.size(), 1u + crh.has_cell_); - size_t num_dy = grads.size() - crh.has_cell_ - 1; - CHECK_EQ(num_dy, crh.seq_length_); - const Tensor dy = MergeInputs(num_dy, grads); - CHECK_EQ(dy.Size(), y.Size()); - const Tensor dhy = grads.at(num_dy); - Tensor dcy; - if (crh.has_cell_) - dcy = grads.at(num_dy + 1); + + CHECK_EQ(dY.Size(), y.Size()); + Shape xshape{y.Size() * crh.input_size_ / crh.hidden_size_ / crh.num_directions_}; + CHECK_EQ(x.shape(), xshape) Tensor dx(xshape, dev, dtype); + Tensor dw(W.shape(), dev, dtype); + Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_}; + CHECK_EQ(hx.shape(), state_shape) Tensor dhx(state_shape, dev, dtype); + Tensor dcx; if (crh.has_cell_) dcx.ResetLike(dhx); + dw.SetValue(0.0f); Block *yb = y.block(), *dyb = dy.block(), *dhyb = dhy.block(), *dcyb = dcy.block(), *xb = x.block(), *cxb = cx.block(), @@ -483,12 +459,7 @@ std::pair, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, cons {yb, dyb, dhyb, dcyb, xb, wb, wspace, rspace}, {dxb, dwb, dhxb, dcxb, wspace, rspace}); - auto data_grads = SplitOutput(num_dy, crh.input_size_, grads, dx); - data_grads.push_back(dhx); - if (crh.has_cell_) - data_grads.push_back(dcx); - - return std::make_pair(data_grads, dw); + return {dx, dhx, dcx, dw}; }; #endif // USE_CUDNN diff --git a/src/model/operation/rnn.h b/src/model/operation/rnn.h old mode 100644 new mode 100755 index 0dbbac9974..7a90ff8995 --- a/src/model/operation/rnn.h +++ b/src/model/operation/rnn.h @@ -69,18 +69,17 @@ class CudnnRNNHandle: public RNNHandle { Tensor reserve_space_; Tensor dropout_state_; }; - Tensor MergeInputs(size_t num, const vector &in); vector SplitOutput(size_t num, size_t dim, const vector &in, const Tensor output); -std::vector> GpuRNNForwardTraining(const CudnnRNNHandle &crh, const vector &inputs, const Tensor &W); +std::vector GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) ; -std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const vector &inputs, const Tensor &W); +std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W); -std::pair, Tensor> GpuRNNBackward(const CudnnRNNHandle &crh, const vector &grads, const vector &cache); +std::vector GpuRNNBackward(const CudnnRNNHandle &crh, const vector &dY, const Tensor &dh, const Tensor &dc, const vector &cache); #endif // USE_CUDNN From 209d4127986f22d435d6431d4db13cf0749e1e5d Mon Sep 17 00:00:00 2001 From: xuewanqi Date: Fri, 20 Jul 2018 
17:16:12 +0000 Subject: [PATCH 4/5] SINGA-386 Implement RNN operation for autograd - redesign some RNN related functions and their APIs. - Now the design of RNN operation is for mini-batch train. - related files can be built without error. --- python/singa/autograd.py | 24 +++--- src/api/model_operation.i | 28 ++++++- src/model/operation/rnn.cc | 147 +++++++++++++++++-------------------- src/model/operation/rnn.h | 23 +++--- 4 files changed, 123 insertions(+), 99 deletions(-) diff --git a/python/singa/autograd.py b/python/singa/autograd.py index 2365f16c44..de426f2840 100755 --- a/python/singa/autograd.py +++ b/python/singa/autograd.py @@ -992,8 +992,8 @@ def forward(self, X, h0, c0, W): # hout_cout: (hout, cout) if lstm, else (hout,) # hout, cout of shape (num_layers * num_directions, batch, # hidden_size) - oututs= 1dTo3d(Y) - + oututs= _1dTo3d(Y) + if self.rnn_mode != 'lstm': return outputs, hout else: @@ -1003,7 +1003,7 @@ def backward(self, dY, dh, dc=CTensor([])): assert training is True and hasattr( self, 'cache'), 'Please set training as True before do BP. ' - dY_1d= 3dTo1d(dY) + dY_1d= _3dTo1d(dY) if dY_1d.device().id() != self.handle.device_id: dY_1d.ToDevice(self.cache[0].device()) @@ -1014,7 +1014,7 @@ def backward(self, dY, dh, dc=CTensor([])): dX_1d, dhout, dcout, dW = singa.GpuRNNBackward( self.handle, dY_1d, dh, dc, self.cache) - dX = 1dTo3d(dX_1d) + dX = _1dTo3d(dX_1d) if self.rnn_mode != 'lstm': return dX, dhout, dW @@ -1064,7 +1064,7 @@ def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first W_Size *= mult * w_size self.W_Size = W_Size - self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True) + self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True) # TODO: assign value of Wi separately self.W.uniform(0.0, 1.0) def __call__(self, inputs, h0, c0=None): @@ -1078,17 +1078,23 @@ def __call__(self, inputs, h0, c0=None): assert c0 is not None, 'Please input c0.' 
self.device_check(h0, c0) - self.handle = signa.CudnnRNNHandle(inputs.data, *SOME_PARAMETERS*) + if not hasattr(self, 'handle'): + self.handle = signa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers, + self.rnn_mode, self.dropout, self.bidirectional, self.W_Size) + elif inputs.shape[0] != self.handle.seq_length_ or inputs.shape[1] != self.handle.batch_size_: + self.handle = signa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers, + self.rnn_mode, self.dropout, self.bidirectional, self.W_Size) + self.handle.device_id = inputs.device.id() - X= 3dTo1d(inputs) + X= _3dTo1d(inputs) outputs = rnn(self.handle, X, h0, c0, self.W) return outputs -def 3dTo1d(self, inputs): +def _3dTo1d(self, inputs): pass -def 1dTo3d(self, *args): +def _1dTo3d(self, *args): pass class LSTM(RNN): diff --git a/src/api/model_operation.i b/src/api/model_operation.i index 435ff1c502..90790f7825 100755 --- a/src/api/model_operation.i +++ b/src/api/model_operation.i @@ -7,7 +7,7 @@ #include "../src/model/operation/convolution.h" #include "../src/model/operation/batchnorm.h" #include "../src/model/operation/pooling.h" - +#include "../src/model/operation/rnn.h" %} namespace singa { @@ -51,6 +51,14 @@ class PoolingHandle { int pooled_width; }; +class RNNHandle { +public: + RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size); + + size_t batch_size_; + size_t seq_length_; +}; #if USE_CUDNN class CudnnConvHandle: public ConvHandle { @@ -106,6 +114,24 @@ Tensor GpuPoolingForward(const CudnnPoolingHandle &cph, const Tensor &x); Tensor GpuPoolingBackward(const CudnnPoolingHandle &cph, const Tensor &dy, const Tensor& x, const Tensor& y); + +class CudnnRNNHandle: public RNNHandle { +public: + CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size); + + size_t batch_size_; + size_t seq_length_; + +}; + +std::vector GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) ; + +std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W); + +std::vector GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const std::vector &cache); + + #endif // USE_CUDNN } //namespace singa diff --git a/src/model/operation/rnn.cc b/src/model/operation/rnn.cc index 79c397d196..590006f830 100755 --- a/src/model/operation/rnn.cc +++ b/src/model/operation/rnn.cc @@ -2,8 +2,12 @@ namespace singa { -RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, - const std::string Rnn_mode, const float Dropout, const bool bidirectional) { +RNNHandle::RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size) { + + CHECK_EQ(input.shape(2), Input_size); + batch_size_ = input.shape(1); + seq_length_= input.shape(0); input_size_ = Input_size; CHECK_GT(input_size_, 0u); @@ -28,60 +32,54 @@ RNNHandle::RNNHandle(const size_t Input_size, const size_t Hidden_size, const si } // the first constant (4) is the size of float // the 
second constant (2, 8, 6) is the number of sets of params - int mult = 1; - if (rnn_mode_ == "relu" || rnn_mode_ == "tanh") - mult *= 1; - else if (rnn_mode_ == "lstm") - mult *= 4; - else if (rnn_mode_ == "gru") - mult *= 3; - if (bidirectional) - mult *= 2; - - weight_size = 0; - for (size_t i = 0; i < num_stacks_; i++) { - size_t dim = hidden_size_ * (input_size_ + hidden_size_ + 2); - if (i > 0) - dim = hidden_size_ * (hidden_size_ + hidden_size_ + 2); - weight_size += mult * dim; - } + weight_size= Weight_size; + }; #ifdef USE_CUDNN -CudnnRNNHandle::CudnnRNNHandle(const vector &inputs, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, - const std::string Rnn_mode, const float Dropout, const bool bidirectional): - RNNHandle(Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional) { +CudnnRNNHandle::CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks, + const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size): + RNNHandle(input, Input_size, Hidden_size, Num_stacks, Rnn_mode, Dropout, bidirectional, Weight_size) { - CHECK_GT(inputs.size(), 1u + has_cell_); - size_t num_x = inputs.size() - has_cell_ - 1; - - DataType dtype = inputs.at(0).data_type(); - if (rnn_desc_ != nullptr) - CHECK_EQ(dtype_, GetCudnnDataType(dtype)) - << "Cannot change cudnn data type during training from " << dtype_ - << " to " << GetCudnnDataType(dtype); - else - dtype_ = GetCudnnDataType(dtype); + DataType dtype = input.data_type(); + dtype_ = GetCudnnDataType(dtype); - UpdateStates(num_x, inputs); + UpdateIODescriptors(input); + ResetHiddenAndCellDescriptors(); + SetRNNDescriptor(input.device()); + UpdateSpaces(seq_length_, input.device()); }; -void CudnnRNNHandle::UpdateStates(size_t num_x, const vector &inputs) { - UpdateIODescriptors(num_x, inputs); - size_t new_batch_size = inputs.at(0).shape(0); - if (batch_size_ != new_batch_size) - ResetHiddenAndCellDescriptors(new_batch_size); - if (rnn_desc_ == nullptr) - SetRNNDescriptor(inputs.at(0).device()); - UpdateSpaces(num_x, inputs.at(0).device()); - batch_size_ = new_batch_size; - seq_length_ = num_x; +CudnnRNNHandle::~CudnnRNNHandle() { + if (weight_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyFilterDescriptor(weight_desc_)); + if (dropout_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyDropoutDescriptor(dropout_desc_)); + if (rnn_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyRNNDescriptor(rnn_desc_)); + if (hx_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(hx_desc_)); + if (hy_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(hy_desc_)); + if (cx_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(cx_desc_)); + if (cy_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(cy_desc_)); + if (dhx_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhx_desc_)); + if (dhy_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dhy_desc_)); + if (dcx_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcx_desc_)); + if (dcy_desc_ != nullptr) + CUDNN_CHECK(cudnnDestroyTensorDescriptor(dcy_desc_)); + DestroyIODescriptors(); }; void CudnnRNNHandle::DestroyIODescriptors() { if (x_descs_ != nullptr) { - for (size_t i = 0; i < max_length_; i++) { + for (size_t i = 0; i < seq_length_; i++) { CUDNN_CHECK(cudnnDestroyTensorDescriptor(x_descs_[i])); CUDNN_CHECK(cudnnDestroyTensorDescriptor(dx_descs_[i])); } @@ -89,7 +87,7 @@ void CudnnRNNHandle::DestroyIODescriptors() { delete [] 
dx_descs_; } if (y_descs_ != nullptr) { - for (size_t i = 0; i < max_length_; i++) { + for (size_t i = 0; i < seq_length_; i++) { CUDNN_CHECK(cudnnDestroyTensorDescriptor(y_descs_[i])); CUDNN_CHECK(cudnnDestroyTensorDescriptor(dy_descs_[i])); } @@ -98,61 +96,60 @@ void CudnnRNNHandle::DestroyIODescriptors() { } }; -void CudnnRNNHandle::UpdateIODescriptors(size_t len, const vector &inputs) { - bool reset = false; - if (max_length_ < len) { - DestroyIODescriptors(); - max_length_ = len; - x_descs_ = new cudnnTensorDescriptor_t[len]; - dx_descs_ = new cudnnTensorDescriptor_t[len]; - y_descs_ = new cudnnTensorDescriptor_t[len]; - dy_descs_ = new cudnnTensorDescriptor_t[len]; - for (size_t i = 0; i < len; i++) { + +void CudnnRNNHandle::UpdateIODescriptors(const Tensor &input) { + x_descs_ = new cudnnTensorDescriptor_t[seq_length_]; + dx_descs_ = new cudnnTensorDescriptor_t[seq_length_]; + y_descs_ = new cudnnTensorDescriptor_t[seq_length_]; + dy_descs_ = new cudnnTensorDescriptor_t[seq_length_]; + for (size_t i = 0; i < seq_length_; i++) { CUDNN_CHECK(cudnnCreateTensorDescriptor(&x_descs_[i])); CUDNN_CHECK(cudnnCreateTensorDescriptor(&dx_descs_[i])); CUDNN_CHECK(cudnnCreateTensorDescriptor(&y_descs_[i])); CUDNN_CHECK(cudnnCreateTensorDescriptor(&dy_descs_[i])); } - reset = true; - } - for (size_t i = 0; i < len; i++) { - CHECK_EQ(inputs[i].shape(1), input_size_); - if (inputs[i].shape(0) != batch_size_ || reset) { + for (size_t i = 0; i < seq_length_; i++) { + CHECK_EQ(input.shape(2), input_size_); int d[3] = {1, 1, 1}, s[3] = {1, 1, 1}; - d[0] = static_cast(inputs[i].shape(0)); + d[0] = static_cast(batch_size_); CHECK_GT(d[0], 0); - d[1] = static_cast(inputs[i].shape(1)); + d[1] = static_cast(input_size_); s[0] = d[1] * d[2]; s[1] = d[2]; CUDNN_CHECK(cudnnSetTensorNdDescriptor(x_descs_[i], dtype_, 3, d, s)); CUDNN_CHECK(cudnnSetTensorNdDescriptor(dx_descs_[i], dtype_, 3, d, s)); - d[0] = static_cast(inputs[i].shape(0)); + d[0] = static_cast(batch_size_); d[1] = static_cast(hidden_size_ * num_directions_); s[0] = d[1] * d[2]; s[1] = d[2]; CUDNN_CHECK(cudnnSetTensorNdDescriptor(y_descs_[i], dtype_, 3, d, s)); CUDNN_CHECK(cudnnSetTensorNdDescriptor(dy_descs_[i], dtype_, 3, d, s)); } - } }; -void CudnnRNNHandle::ResetHiddenAndCellDescriptors(size_t batch_size) { - if (batch_size_ == 0) { +void CudnnRNNHandle::ResetHiddenAndCellDescriptors() { + if (cx_desc_ == nullptr) CUDNN_CHECK(cudnnCreateTensorDescriptor(&cx_desc_)); + if (dcx_desc_ == nullptr) CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcx_desc_)); + if (cy_desc_ == nullptr) CUDNN_CHECK(cudnnCreateTensorDescriptor(&cy_desc_)); + if (dcy_desc_ == nullptr) CUDNN_CHECK(cudnnCreateTensorDescriptor(&dcy_desc_)); + if (hx_desc_ == nullptr) CUDNN_CHECK(cudnnCreateTensorDescriptor(&hx_desc_)); + if (dhx_desc_ == nullptr) CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhx_desc_)); + if (hy_desc_ == nullptr) CUDNN_CHECK(cudnnCreateTensorDescriptor(&hy_desc_)); + if (dhy_desc_ == nullptr) CUDNN_CHECK(cudnnCreateTensorDescriptor(&dhy_desc_)); - } int dim[3] = {1, 1, 1}; dim[0] = static_cast(num_stacks_ * num_directions_); - dim[1] = static_cast(batch_size); + dim[1] = static_cast(batch_size_); dim[2] = static_cast(hidden_size_); int stride[3] = {1, 1, 1}; stride[0] = dim[1] * dim[2]; @@ -229,7 +226,7 @@ void CudnnRNNHandle::UpdateSpaces(size_t seq_length, shared_ptr dev) { reserve_space_ = Tensor(Shape{count}, dev, kChar); // reserve_space_.SetValue(0); } -} +}; Tensor MergeInputs(size_t num, const vector &in) { if (num == 1) @@ -265,7 +262,7 @@ vector 
 std::vector GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W) {
   DataType dtype = input.data_type();
-  auto dev = input.at(0).device();
+  auto dev = input.device();
 
   Shape outshape{input.Size() * crh.hidden_size_ / crh.input_size_ * crh.num_directions_};
@@ -273,7 +270,6 @@ std::vector GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tenso
   // LOG(INFO) << "output size " << output.Size();
 
   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-  CHECK_EQ(hx.shape(), state_shape);
   Tensor hy(state_shape, dev, dtype);
 
   Tensor cy;
@@ -339,7 +335,6 @@ std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tens
   // LOG(INFO) << "output size " << output.Size();
 
   Shape state_shape{crh.num_stacks_ * crh.num_directions_, crh.batch_size_, crh.hidden_size_};
-  CHECK_EQ(hx.shape(), state_shape);
   Tensor hy(state_shape, dev, dtype);
 
   Tensor cy;
@@ -389,7 +384,7 @@ std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tens
   return {output, hy, cy};
 };
 
-std::vector GpuRNNBackward(const CudnnRNNHandle &crh, const vector &dY, const Tensor &dh, const Tensor &dc, const vector &cache) {
+std::vector GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const std::vector &cache) {
   const Tensor x = cache[0];
   const Tensor y = cache[1];
   const Tensor hx = cache[2];
@@ -399,18 +394,14 @@ std::vector GpuRNNBackward(const CudnnRNNHandle &crh, const vector
-  CudnnRNNHandle(const vector &inputs, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
-                 const std::string Rnn_mode, const float Dropout, const bool bidirectional);
-  void UpdateStates(size_t num_x, const vector &inputs);
+  CudnnRNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
+                 const std::string Rnn_mode, const float Dropout, const bool bidirectional, const size_t Weight_size);
+  ~CudnnRNNHandle();
+  void DestroyIODescriptors();
 
-  void UpdateIODescriptors(size_t len, const vector &inputs);
-  void ResetHiddenAndCellDescriptors(size_t batch_size);
+  void UpdateIODescriptors(const Tensor &input);
+  void ResetHiddenAndCellDescriptors();
   void SetRNNDescriptor(shared_ptr dev);
   void UpdateSpaces(size_t seq_length, shared_ptr dev);
 
@@ -79,7 +80,7 @@ std::vector GpuRNNForwardTraining(const CudnnRNNHandle &crh, const Tenso
 
 std::vector GpuRNNForwardInference(const CudnnRNNHandle &crh, const Tensor &input, const Tensor &hx, const Tensor &cx, const Tensor &W);
 
-std::vector GpuRNNBackward(const CudnnRNNHandle &crh, const vector &dY, const Tensor &dh, const Tensor &dc, const vector &cache);
+std::vector GpuRNNBackward(const CudnnRNNHandle &crh, const Tensor &dY, const Tensor &dhy, const Tensor &dcy, const vector &cache);
 
 #endif  // USE_CUDNN

From 4a141014a2ac629b5719fee1aa1f6ad6f5177263 Mon Sep 17 00:00:00 2001
From: xuewanqi
Date: Thu, 16 Aug 2018 14:35:32 +0000
Subject: [PATCH 5/5] SINGA-386 Implement RNN operation for autograd

- Fix some bugs and make some design modifications for RNN, LSTM, GRU, etc.,
  which are computed by calling cuDNN functions. The implemented layers all
  work well and pass the shape checks for both the forward and the backward
  step.
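For context, a rough usage sketch of the Python layer API that this patch settles on. It is illustrative only and not part of the committed code: the tensor shapes follow the comments in the diff below, while helpers such as device.create_cuda_gpu(), tensor.Tensor(...).gaussian(...) and the autograd.training flag are assumed from the SINGA Python API of this period and may need adjustment.

    # Hypothetical usage sketch for the cuDNN-backed LSTM layer (assumptions noted above).
    from singa import autograd, device, tensor

    dev = device.create_cuda_gpu()  # the CudnnRNNHandle path is GPU/cuDNN only
    seq_len, batch, input_size, hidden_size = 4, 2, 8, 16

    # X of shape (seq_len, batch, input_size)
    x = tensor.Tensor((seq_len, batch, input_size), dev)
    x.gaussian(0.0, 0.1)
    # h0, c0 of shape (num_layers * num_directions, batch, hidden_size)
    h0 = tensor.Tensor((1, batch, hidden_size), dev)
    h0.set_value(0.0)
    c0 = tensor.Tensor((1, batch, hidden_size), dev)  # required because rnn_mode is 'lstm'
    c0.set_value(0.0)

    autograd.training = True
    lstm = autograd.LSTM(input_size, hidden_size)  # allocates W; the CudnnRNNHandle is built on the first call
    y, hy, cy = lstm(x, h0, c0)                    # y: (seq_len, batch, hidden_size); hy, cy: final states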
---
 python/singa/autograd.py   | 70 +++++++++++++++++++++++---------------
 src/api/model_operation.i  |  6 ++++
 src/model/operation/rnn.cc |  3 +-
 3 files changed, 51 insertions(+), 28 deletions(-)

diff --git a/python/singa/autograd.py b/python/singa/autograd.py
index de426f2840..7469d6ff40 100755
--- a/python/singa/autograd.py
+++ b/python/singa/autograd.py
@@ -23,6 +23,7 @@
 import numpy as np
 import math
 
+from singa import tensor
 from .tensor import Tensor
 from . import layer
 from singa.proto import model_pb2
@@ -969,12 +970,13 @@ class _RNN(Operation):
     def __init__(self, handle):
         self.handle = handle
 
-    def forward(self, X, h0, c0, W):
+    #def forward(self, X, h0, c0, W):
+    def forward(self, X, h0, W, c0=None):
         # X of shape (seq_len, batch, input_size)
        # h0_c0: (h0, c0) if lstm, else (h0,)
         # h0, c0 of shape (num_layers * num_directions, batch, hidden_size)
         if c0 is None:
-            assert self.rnn_mode != 'lstm'
+            assert self.handle.rnn_mode_ != 'lstm'
             c0= CTensor([])  # CTensor([]) and Tensor cx are the same?
 
         if self.handle.device_id == -1:
@@ -992,38 +994,49 @@ def forward(self, X, h0, c0, W):
         # hout_cout: (hout, cout) if lstm, else (hout,)
         # hout, cout of shape (num_layers * num_directions, batch,
         # hidden_size)
-        oututs= _1dTo3d(Y)
+
+        #oututs= _1dTo3d(Y)
+        shape=(self.handle.seq_length_, self.handle.batch_size_, self.handle.hidden_size_)
+        outputs = singa.Reshape(Y, shape)
 
-        if self.rnn_mode != 'lstm':
+        if self.handle.rnn_mode_ != 'lstm':
             return outputs, hout
         else:
             return outputs, hout, cout
 
-    def backward(self, dY, dh, dc=CTensor([])):
+    def backward(self, dY, dh=CTensor([]), dc=CTensor([])):
         assert training is True and hasattr(
             self, 'cache'), 'Please set training as True before do BP. '
 
-        dY_1d= _3dTo1d(dY)
+        #dY_1d= _3dTo1d(dY)
 
-        if dY_1d.device().id() != self.handle.device_id:
-            dY_1d.ToDevice(self.cache[0].device())
+        if dY.device().id() != self.handle.device_id:
+            dY.ToDevice(self.cache[0].device())
 
         if self.handle.device_id == -1:
             raise NotImplementedError
         else:
             dX_1d, dhout, dcout, dW = singa.GpuRNNBackward(
-                self.handle, dY_1d, dh, dc, self.cache)
+                self.handle, dY, dh, dc, self.cache)
 
-            dX = _1dTo3d(dX_1d)
+            #dX = _1dTo3d(dX_1d)
+            shape=(self.handle.seq_length_, self.handle.batch_size_, self.handle.input_size_)
+            dX = singa.Reshape(dX_1d, shape)
 
-        if self.rnn_mode != 'lstm':
+        if self.handle.rnn_mode_ != 'lstm':
             return dX, dhout, dW
         else:
-            return dX, dhout, dcout, dW
+            return dX, dhout, dW, dcout
 
 
-def rnn(handle, x, h0, c0, W):
-    return _RNN(handle)(x, h0, c0, W)
+#def rnn(handle, x, h0, c0, W):
+    # return _RNN(handle)(x, h0, c0, W)
+
+def rnn(handle, x, h0, W, c0):
+    if c0 is None:
+        return _RNN(handle)(x, h0, W)
+    else:
+        return _RNN(handle)(x, h0, W, c0)
 
 
 class RNN(Layer):
@@ -1054,14 +1067,15 @@ def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first
         if self.bidirectional:
             mult *= 2
 
+        W_Size = 0
         for k in range(num_layers):
-            if k == 1:
+            if k == 0:
                 w_size = self.hidden_size * \
                     (self.input_size + self.hidden_size + 2)
             else:
                 w_size = self.hidden_size * \
                     (self.hidden_size + self.hidden_size + 2)
-            W_Size *= mult * w_size
+            W_Size += mult * w_size
 
         self.W_Size = W_Size
         self.W = Tensor(shape=(W_Size,), requires_grad=True, stores_grad=True)
         # TODO: assign value of Wi separately
@@ -1077,33 +1091,35 @@ def __call__(self, inputs, h0, c0=None):
 
         if self.rnn_mode == 'lstm':
             assert c0 is not None, 'Please input c0.'
             self.device_check(h0, c0)
+        else:
+            assert c0 is None, 'only lstm needs input c0'
 
         if not hasattr(self, 'handle'):
-            self.handle = signa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
+            self.handle = singa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
                                                self.rnn_mode, self.dropout, self.bidirectional, self.W_Size)
         elif inputs.shape[0] != self.handle.seq_length_ or inputs.shape[1] != self.handle.batch_size_:
-            self.handle = signa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
+            self.handle = singa.CudnnRNNHandle(inputs.data, self.input_size, self.hidden_size, self.num_layers,
                                                self.rnn_mode, self.dropout, self.bidirectional, self.W_Size)
 
         self.handle.device_id = inputs.device.id()
 
-        X= _3dTo1d(inputs)
-        outputs = rnn(self.handle, X, h0, c0, self.W)
+        #X= _3dTo1d(inputs)
+        X=inputs
+        outputs = rnn(self.handle, X, h0, self.W, c0)
+        #outputs = rnn(self.handle, X, h0, self.W)
+        #outputs=tensor.to_numpy(outputs[0])
+        #print(outputs.shape)
+        #print(outputs)
         return outputs
 
 
-def _3dTo1d(self, inputs):
-    pass
-
-def _1dTo3d(self, *args):
-    pass
 
 class LSTM(RNN):
 
     def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
-        super(LSTM, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional,rnn_mode='lstm')
+        super(LSTM, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, rnn_mode='lstm')
 
 
 class GRU(RNN):
 
     def __init__(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0, bidirectional=False):
-        super(GRU, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional,rnn_mode='gru')
+        super(GRU, self).__init__(input_size, hidden_size, num_layers, bias, batch_first, dropout, bidirectional, rnn_mode='gru')
diff --git a/src/api/model_operation.i b/src/api/model_operation.i
index 90790f7825..5cec92e4ad 100755
--- a/src/api/model_operation.i
+++ b/src/api/model_operation.i
@@ -58,6 +58,9 @@ public:
 
   size_t batch_size_;
   size_t seq_length_;
+  size_t input_size_;
+  size_t hidden_size_;
+  std::string rnn_mode_;
 };
 
 #if USE_CUDNN
@@ -122,6 +125,9 @@ public:
 
   size_t batch_size_;
   size_t seq_length_;
+  size_t input_size_;
+  size_t hidden_size_;
+  std::string rnn_mode_;
 };
 
diff --git a/src/model/operation/rnn.cc b/src/model/operation/rnn.cc
index 590006f830..e8c614e072 100755
--- a/src/model/operation/rnn.cc
+++ b/src/model/operation/rnn.cc
@@ -1,5 +1,5 @@
 #include "./rnn.h"
-
+#include 
 namespace singa {
 
 RNNHandle::RNNHandle(const Tensor &input, const size_t Input_size, const size_t Hidden_size, const size_t Num_stacks,
@@ -203,6 +203,7 @@ void CudnnRNNHandle::SetRNNDescriptor(shared_ptr dev) {
   CUDNN_CHECK(cudnnGetRNNParamsSize(ctx->cudnn_handle, rnn_desc_, x_descs_[0], &weight_size_, dtype_));
   // check the size manually calculated
+  //std::cout<
   int filter_dim[3] = {static_cast(weight_size_), 1, 1};
   CUDNN_CHECK(cudnnCreateFilterDescriptor(&weight_desc_));