Add weight prepacking to LSTM kernel (#5305)
tracysh authored Sep 29, 2020
1 parent 11c194c commit f07059c
Showing 7 changed files with 213 additions and 91 deletions.
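
The idea behind the change: when the LSTM's W and R inputs are constant initializers, there is no reason to let MLAS re-pack them inside every GEMM call. Instead, each direction's weight matrix is packed into MLAS's internal B-layout once, up front, and the packed buffer is reused on every Compute. A minimal sketch of that pattern, using only the MLAS calls that appear in this diff — `PackWeightsOnce` and the malloc-based allocation are illustrative, not part of the commit (the kernel allocates through onnxruntime's allocator):

```cpp
#include <cstddef>
#include <cstdlib>

#include "core/mlas/inc/mlas.h"  // MlasGemmPackBSize, MlasGemmPackB (onnxruntime-internal header)

// Pack one direction's [N, K] weight matrix once, ahead of inference.
// It is consumed transposed by the GEMM, hence CblasTrans.
void* PackWeightsOnce(const float* weights, size_t N, size_t K) {
  const size_t packed_size = MlasGemmPackBSize(N, K);
  if (packed_size == 0) {
    return nullptr;  // MLAS cannot pack this shape; fall back to the unpacked GEMM
  }
  void* packed = std::malloc(packed_size);
  MlasGemmPackB(CblasTrans, N, K, weights, /*ldb=*/K, packed);
  return packed;  // reused by every subsequent GEMM call instead of re-packing
}
```
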
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/rnn/deep_cpu_gru.cc
@@ -272,7 +272,7 @@ Status DeepCpuGruOp::ComputeImpl(OpKernelContext& context) const {
int batch_size = gsl::narrow<int>(X_shape[1]);
int input_size = gsl::narrow<int>(X_shape[2]);

-auto status = ValidateCommonRnnInputs(X, W, R, B, 3, sequence_lens, initial_h, num_directions_, hidden_size_);
+auto status = ValidateCommonRnnInputs(X, W.Shape(), R.Shape(), B, 3, sequence_lens, initial_h, num_directions_, hidden_size_);
ORT_RETURN_IF_ERROR(status);

// GRU outputs are optional but must be in the same order
131 changes: 96 additions & 35 deletions onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.cc
@@ -211,8 +211,8 @@ class UniDirectionalLstm {
const ActivationFuncs::Entry& activation_func_h, float clip, concurrency::ThreadPool* thread_pool);

void Compute(const gsl::span<const T>& inputs, const gsl::span<const int>& sequence_lengths, int num_directions,
-const gsl::span<const T>& input_weights, const gsl::span<const T>& recurrent_weights,
-gsl::span<T>& outputs, gsl::span<T>& final_hidden_state, gsl::span<T>& final_cell_state);
+const GemmWeights<T>& input_weights, const GemmWeights<T>& recurrent_weights, gsl::span<T>& outputs,
+gsl::span<T>& final_hidden_state, gsl::span<T>& final_cell_state);

~UniDirectionalLstm() = default;

@@ -290,20 +290,74 @@ class UniDirectionalLstm {

} // namespace detail

+Status DeepCpuLstmOp::TryPackWeights(const Tensor& weights, PackedWeights& packed_weights, bool& is_packed) {
+  const auto& shape = weights.Shape();
+  if (shape.NumDimensions() != 3) {
+    return Status::OK();
+  }
+
+  // weights: [num_directions, 4*hidden_size, input_size]
+  // recurrence weights: [num_directions, 4*hidden_size, hidden_size]
+  const size_t N = static_cast<size_t>(shape[1]);
+  const size_t K = static_cast<size_t>(shape[2]);
+
+  if ((shape[0] != num_directions_) || (N != static_cast<size_t>(hidden_size_ * 4))) {
+    return Status::OK();
+  }
+
+  const size_t packed_weights_size = MlasGemmPackBSize(N, K);
+  if (packed_weights_size == 0) {
+    return Status::OK();
+  }
+
+  auto alloc = Info().GetAllocator(0, OrtMemTypeDefault);
+  auto* packed_weights_data = alloc->Alloc(SafeInt<size_t>(packed_weights_size) * num_directions_);
+  packed_weights.buffer_ = BufferUniquePtr(packed_weights_data, BufferDeleter(alloc));
+  packed_weights.weights_size_ = packed_weights_size;
+  packed_weights.shape_ = shape;
+
+  const auto* weights_data = weights.Data<float>();
+  for (int i = 0; i < num_directions_; i++) {
+    MlasGemmPackB(CblasTrans, N, K, weights_data, K, packed_weights_data);
+    packed_weights_data = static_cast<uint8_t*>(packed_weights_data) + packed_weights_size;
+    weights_data += N * K;
+  }
+
+  is_packed = true;
+  return Status::OK();
+}
+
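TryPackWeights makes a single allocation holding num_directions_ packed blocks back to back, each packed_weights_size bytes. A hypothetical accessor showing how a consumer would locate direction d's block — GemmWeights presumably does the equivalent internally; the member names match the diff, but this helper is illustration only:

```cpp
#include <cstddef>
#include <cstdint>

// Not part of the commit: find direction d's packed block inside the
// single allocation that TryPackWeights made above.
const void* PackedBlockForDirection(const rnn::detail::PackedWeights& pw, int d) {
  return static_cast<const uint8_t*>(pw.buffer_.get()) +
         static_cast<size_t>(d) * pw.weights_size_;
}
```
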
+#if !defined(USE_MKLML_FOR_BLAS)
+Status DeepCpuLstmOp::PrePack(const Tensor& tensor, int input_idx, bool& is_packed) {
+  is_packed = false;
+
+  if (tensor.IsDataType<float>()) {
+    if (input_idx == 1) {
+      return TryPackWeights(tensor, packed_W_, is_packed);
+    } else if (input_idx == 2) {
+      return TryPackWeights(tensor, packed_R_, is_packed);
+    }
+  }
+
+  return Status::OK();
+}
+#endif
+
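PrePack is the OpKernel hook onnxruntime calls once per constant initializer while the session is being set up; input indices 1 and 2 are the ONNX LSTM op's W and R inputs, and reporting is_packed = true lets the runtime drop the original weight buffer. The USE_MKLML_FOR_BLAS guard skips packing when GEMMs are routed through MKL-ML rather than MLAS, where the packed layout would never be consumed. Roughly how the session presumably drives the hook — a hedged sketch, with `kernel` and `weight_initializer` as illustrative stand-ins:

```cpp
// Hypothetical driver; the real call sites live in the session-state
// initializer, not in user code.
bool is_packed = false;
ORT_RETURN_IF_ERROR(kernel.PrePack(weight_initializer, /*input_idx=*/1, is_packed));
if (is_packed) {
  // The kernel now owns a packed copy in packed_W_, so the runtime can free
  // the original initializer; Compute will then see Input<Tensor>(1) == nullptr.
}
```
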
Status DeepCpuLstmOp::Compute(OpKernelContext* context) const {
const Tensor& X = *context->Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]

Status status;
// auto& logger = context->Logger();

-if (X.IsDataType<float>())
+if (X.IsDataType<float>()) {
status = ComputeImpl<float>(*context);
-else if (X.IsDataType<double>()) {
+} else if (X.IsDataType<double>()) {
/* Need to update all the helpers to support double...
status = ComputeImpl<double>(*context); */
ORT_NOT_IMPLEMENTED("LSTM operator does not support double yet");
-} else
+} else {
ORT_THROW("Invalid data type for LSTM operator of ", X.DataType());
+}

return status;
}
@@ -322,8 +376,10 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
auto& logger = context.Logger();

const Tensor& X = *context.Input<Tensor>(0); // inputs. [seq_length, batch_size, input_size]
-const Tensor& W = *context.Input<Tensor>(1); // weights. [num_directions, 4*hidden_size, input_size]
-const Tensor& R = *context.Input<Tensor>(2); // recurrence weights. [num_directions, 4*hidden_size, hidden_size]
+const Tensor* W = packed_W_.buffer_ ? nullptr : context.Input<Tensor>(1);
+// weights. [num_directions, 4*hidden_size, input_size]
+const Tensor* R = packed_R_.buffer_ ? nullptr : context.Input<Tensor>(2);
+// recurrence weights. [num_directions, 4*hidden_size, hidden_size]

// optional
const Tensor* B = context.Input<Tensor>(3); // bias. [num_directions, 8*hidden_size]
@@ -332,13 +388,16 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
const Tensor* initial_c = context.Input<Tensor>(6); // initial cell. [num_directions, batch_size, hidden_size]
const Tensor* P = context.Input<Tensor>(7); // peephole weights. [num_directions, 3*hidden_size]

-auto& X_shape = X.Shape();
+const auto& X_shape = X.Shape();

int seq_length = gsl::narrow<int>(X_shape[0]);
int batch_size = gsl::narrow<int>(X_shape[1]);
int input_size = gsl::narrow<int>(X_shape[2]);

-Status status = ValidateInputs(X, W, R, B, sequence_lens, initial_h, initial_c, P, batch_size);
+const auto& W_shape = (W != nullptr) ? W->Shape() : packed_W_.shape_;
+const auto& R_shape = (R != nullptr) ? R->Shape() : packed_R_.shape_;
+
+Status status = ValidateInputs(X, W_shape, R_shape, B, sequence_lens, initial_h, initial_c, P, batch_size);
ORT_RETURN_IF_ERROR(status);

// LSTM outputs are optional but must be in the same order
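
Note the null checks above: once prepacking has run, context.Input<Tensor>(1) and (2) return nullptr because the initializers were released, so shape-dependent logic has to fall back to the shape recorded in packed_W_/packed_R_ at pack time. A worked shape example, with illustrative numbers only:

```cpp
// Bidirectional LSTM, hidden_size = 128, input_size = 64 (numbers illustrative):
//   W: [2, 4*128, 64]   -> per-direction GEMM B is 512 x 64  (N = 512, K = 64)
//   R: [2, 4*128, 128]  -> per-direction GEMM B is 512 x 128 (N = 512, K = 128)
const TensorShape& W_shape = (W != nullptr) ? W->Shape() : packed_W_.shape_;
const size_t N = static_cast<size_t>(W_shape[1]);  // 512: stacked i/o/f/c rows
const size_t K = static_cast<size_t>(W_shape[2]);  // 64: GEMM reduction width
```
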
@@ -370,8 +429,9 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
status = context.GetTempSpaceAllocator(&alloc);
ORT_RETURN_IF_ERROR(status);

-gsl::span<const T> input_weights = W.DataAsSpan<T>();
-gsl::span<const T> recurrent_weights = R.DataAsSpan<T>();
+const auto* input_weights = (W != nullptr) ? W->Data<T>() : nullptr;
+const auto* recurrent_weights = (R != nullptr) ? R->Data<T>() : nullptr;
+
gsl::span<const T> bias = B != nullptr ? B->DataAsSpan<T>() : gsl::span<const T>();
gsl::span<const T> peephole_weights = P != nullptr ? P->DataAsSpan<T>() : gsl::span<const T>();

@@ -381,8 +441,9 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
const size_t bias_size_per_direction = 8 * hidden_size_;
const size_t peephole_weights_size_per_direction = 3 * hidden_size_;

-gsl::span<const T> input_weights_1 = input_weights.subspan(0, input_weights_size_per_direction);
-gsl::span<const T> recurrent_weights_1 = recurrent_weights.subspan(0, hidden_weights_size_per_direction);
+GemmWeights<T> input_weights_1(0, input_weights, input_weights_size_per_direction, packed_W_);
+GemmWeights<T> recurrent_weights_1(0, recurrent_weights, hidden_weights_size_per_direction, packed_R_);
+
gsl::span<const T> bias_1 = bias.empty() ? bias : bias.subspan(0, bias_size_per_direction);
gsl::span<const T> peephole_weights_1 =
peephole_weights.empty() ? peephole_weights : peephole_weights.subspan(0, peephole_weights_size_per_direction);
@@ -427,11 +488,10 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
gsl::span<T> last_cell_1 = last_cell.subspan(0, last_cell_size_per_direction);

if (direction_ == Direction::kBidirectional) {
+GemmWeights<T> input_weights_2(1, input_weights, input_weights_size_per_direction, packed_W_);
+GemmWeights<T> recurrent_weights_2(1, recurrent_weights, hidden_weights_size_per_direction, packed_R_);
+
// spans for second direction
-gsl::span<const T> input_weights_2 =
-    input_weights.subspan(input_weights_size_per_direction, input_weights_size_per_direction);
-gsl::span<const T> hidden_weights_2 =
-    recurrent_weights.subspan(hidden_weights_size_per_direction, hidden_weights_size_per_direction);
gsl::span<const T> bias_2 = bias.empty() ? bias : bias.subspan(bias_size_per_direction, bias_size_per_direction);
gsl::span<const T> peephole_weights_2 =
peephole_weights.empty() ? peephole_weights : peephole_weights.subspan(peephole_weights_size_per_direction, peephole_weights_size_per_direction);
@@ -459,8 +519,8 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {

fw.Compute(input, sequence_lens_span, num_directions_, input_weights_1, recurrent_weights_1, output_1,
hidden_output_1, last_cell_1);
-bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, hidden_weights_2, output_2, hidden_output_2,
-           last_cell_2);
+bw.Compute(input, sequence_lens_span, num_directions_, input_weights_2, recurrent_weights_2, output_2,
+           hidden_output_2, last_cell_2);
} else {
detail::UniDirectionalLstm<T> fw(alloc, logger, seq_length, batch_size, input_size, hidden_size_, direction_,
input_forget_, bias_1, peephole_weights_1, initial_hidden_1, initial_cell_1,
@@ -481,11 +541,11 @@ Status DeepCpuLstmOp::ComputeImpl(OpKernelContext& context) const {
return Status::OK();
}

-Status DeepCpuLstmOp::ValidateInputs(const Tensor& X, const Tensor& W, const Tensor& R, const Tensor* B,
-                                     const Tensor* sequence_lens, const Tensor* initial_h, const Tensor* initial_c,
-                                     const Tensor* P, int batch_size) const {
+Status DeepCpuLstmOp::ValidateInputs(const Tensor& X, const TensorShape& W_shape, const TensorShape& R_shape,
+                                     const Tensor* B, const Tensor* sequence_lens, const Tensor* initial_h,
+                                     const Tensor* initial_c, const Tensor* P, int batch_size) const {
auto status =
-    rnn::detail::ValidateCommonRnnInputs(X, W, R, B, 4, sequence_lens, initial_h, num_directions_, hidden_size_);
+    rnn::detail::ValidateCommonRnnInputs(X, W_shape, R_shape, B, 4, sequence_lens, initial_h, num_directions_, hidden_size_);
ORT_RETURN_IF_ERROR(status);

if (initial_c != nullptr) {
@@ -680,8 +740,8 @@ void UniDirectionalLstm<T>::LoadBias(const gsl::span<const T>& WbRb_values) {
template <typename T>
void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
const gsl::span<const int>& sequence_lengths_arg, const int num_directions,
-const gsl::span<const T>& input_weights,
-const gsl::span<const T>& recurrent_weights, gsl::span<T>& outputs,
+const GemmWeights<T>& input_weights, const GemmWeights<T>& recurrent_weights,
+gsl::span<T>& outputs,
gsl::span<T>& final_hidden_state, gsl::span<T>& final_cell_state) {
// copy spans (just T* and size, not data in span) as we may change them
gsl::span<const T> inputs = inputs_arg;
@@ -736,9 +796,9 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
const int total_rows = max_sequence_length * batch_size_;

// apply the weights to all the inputs and save to output_IOFC
-ComputeGemm(total_rows, hidden_size_x4, input_size_, alpha, inputs.cbegin(), inputs.cend(), input_size_,
-            input_weights.cbegin(), input_weights.cend(), // W[iofc]
-            input_size_, beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4, thread_pool_);
+ComputeGemm(total_rows, hidden_size_x4, input_size_, alpha, inputs.cbegin(), inputs.cend(),
+            input_weights,
+            beta, output_iofc_.begin(), output_iofc_.end(), hidden_size_x4, thread_pool_);

DumpMatrix("Xt*(W[iofc]^T)", output_iofc_.data(), total_rows, hidden_size_x4);

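ComputeGemm now takes a GemmWeights<T> instead of raw weight iterators plus a leading dimension, so the packed-versus-unpacked decision moves inside the helper. GemmWeights and ComputeGemm are defined in rnn_helpers.h, which did not render in this view; the following is a hedged sketch of the dispatch it presumably performs — the struct, its field names, and the exact MlasGemm overloads are assumptions, not the commit's code:

```cpp
#include <cstddef>
#include "core/mlas/inc/mlas.h"  // MlasGemm overloads (onnxruntime-internal header)

struct GemmWeightsSketch {       // assumed stand-in for rnn::detail::GemmWeights<T>
  const float* raw = nullptr;    // this direction's unpacked weights
  const void* packed = nullptr;  // this direction's prepacked MLAS block
};

// C = alpha * A * B^T + beta * C, choosing the packed path when available.
void ComputeGemmSketch(size_t M, size_t N, size_t K, float alpha, const float* A,
                       float beta, float* C, const GemmWeightsSketch& w,
                       MLAS_THREADPOOL* thread_pool) {
  if (w.packed != nullptr) {
    // Packed path: B was transposed and laid out once at PrePack time.
    MlasGemm(CblasNoTrans, M, N, K, alpha, A, K, w.packed, beta, C, N, thread_pool);
  } else {
    // Legacy path: MLAS packs B internally on every call, as before this commit.
    MlasGemm(CblasNoTrans, CblasTrans, M, N, K, alpha, A, K, w.raw, K, beta, C, N,
             thread_pool);
  }
}
```
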
@@ -783,10 +843,10 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,

// calculate Xt*(W[iofc]^T) + Ht-1*R[iofc]
// Do it sequentially to avoid nested parallelism
-ComputeGemm(local_fused_hidden_rows, hidden_size_x4, hidden_size_, alpha, previous_state,
-            previous_state_end, // Ht-1
-            hidden_size_, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
-            hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
+ComputeGemm(local_fused_hidden_rows, hidden_size_x4, hidden_size_, alpha,
+            previous_state, previous_state_end, // Ht-1
+            recurrent_weights, // R[iofc]
+            beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4, nullptr);

DumpMatrix("Xt*(W[iofc]^T) + Ht-t*R[iofc]" + row_str, &*step_out_IOFC, local_fused_hidden_rows, hidden_size_x4);
@@ -861,9 +921,10 @@ void UniDirectionalLstm<T>::Compute(const gsl::span<const T>& inputs_arg,
span_T_iter step_out_IOFC = output_iofc_.begin() + (step * batch_size_) * hidden_size_x4;

// calculate Xt*(W[iofc]^T) + Ht-1*R[iofc]
-ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, alpha, previous_state, previous_state_end, // Ht-1
-            hidden_size_, recurrent_weights.cbegin(), recurrent_weights.cend(), // R[iofc]
-            hidden_size_, beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
+ComputeGemm(batch_size_, hidden_size_x4, hidden_size_, alpha,
+            previous_state, previous_state_end, // Ht-1
+            recurrent_weights, // R[iofc]
+            beta, step_out_IOFC, output_iofc_.end(), // input contains Xt*(W[iofc]^T)
hidden_size_x4, thread_pool_);

span_T_iter batched_output;
12 changes: 10 additions & 2 deletions onnxruntime/core/providers/cpu/rnn/deep_cpu_lstm.h
@@ -50,17 +50,22 @@ class DeepCpuLstmOp final : public OpKernel {
activation_func_betas);
}

+#if !defined(USE_MKLML_FOR_BLAS)
+Status PrePack(const Tensor& tensor, int input_idx, bool& is_packed) override;
+#endif
Status Compute(OpKernelContext* context) const override;

~DeepCpuLstmOp() override = default;

private:
+Status TryPackWeights(const Tensor& weights, rnn::detail::PackedWeights& packed_weights, bool& is_packed);
+
template <typename T>
Status ComputeImpl(OpKernelContext& context) const;

Status ValidateInputs(const Tensor& X,
-const Tensor& W,
-const Tensor& R,
+const TensorShape& W,
+const TensorShape& R,
const Tensor* B,
const Tensor* sequence_lens,
const Tensor* initial_h,
@@ -75,6 +80,9 @@ class DeepCpuLstmOp final : public OpKernel {
float clip_;
bool input_forget_ = false;

+rnn::detail::PackedWeights packed_W_;
+rnn::detail::PackedWeights packed_R_;
+
rnn::detail::ActivationFuncs activation_funcs_;

};
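
Of the seven changed files, the last two never rendered in this capture; one of them is rnn_helpers.h, where rnn::detail::PackedWeights and GemmWeights are declared. Judging from the members this commit touches (buffer_, weights_size_, shape_), PackedWeights is presumably close to this sketch:

```cpp
struct PackedWeights {
  BufferUniquePtr buffer_;   // one allocation: num_directions packed blocks
  size_t weights_size_ = 0;  // bytes per direction, from MlasGemmPackBSize
  TensorShape shape_;        // original weight shape, kept for validation
};
```
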
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/cpu/rnn/rnn.cc
@@ -119,7 +119,7 @@ Status RNN<float>::Compute(OpKernelContext* ctx) const {
int64_t batch_size = X.Shape()[1];
int64_t input_size = X.Shape()[2];

-auto status = rnn::detail::ValidateCommonRnnInputs(X, W, R, B, 1, sequence_lens, initial_h,
+auto status = rnn::detail::ValidateCommonRnnInputs(X, W.Shape(), R.Shape(), B, 1, sequence_lens, initial_h,
num_directions, hidden_size_);
ORT_RETURN_IF_ERROR(status);

6 changes: 2 additions & 4 deletions onnxruntime/core/providers/cpu/rnn/rnn_helpers.cc
@@ -24,17 +24,15 @@ namespace detail {
using namespace ::onnxruntime::common;

Status ValidateCommonRnnInputs(const Tensor& X,
-const Tensor& W,
-const Tensor& R,
+const TensorShape& W_shape,
+const TensorShape& R_shape,
const Tensor* B,
int WRB_dim_1_multipler,
const Tensor* sequence_lens,
const Tensor* initial_h,
int64_t num_directions,
int64_t hidden_size) {
auto& X_shape = X.Shape();
-auto& W_shape = W.Shape();
-auto& R_shape = R.Shape();

int64_t seq_length = X_shape[0];
int64_t batch_size = X_shape[1];
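
Switching ValidateCommonRnnInputs from Tensor to TensorShape parameters is what lets validation keep working after the weight initializers have been dropped: callers that still hold tensors pass W.Shape()/R.Shape(), as the GRU and RNN kernels above now do, while the LSTM can substitute the shapes recorded at pack time. Mirroring the ComputeImpl change earlier in this diff:

```cpp
// Shapes survive prepacking even though the tensors may not:
const auto& W_shape = (W != nullptr) ? W->Shape() : packed_W_.shape_;
const auto& R_shape = (R != nullptr) ? R->Shape() : packed_R_.shape_;
ORT_RETURN_IF_ERROR(rnn::detail::ValidateCommonRnnInputs(
    X, W_shape, R_shape, B, 4, sequence_lens, initial_h, num_directions_, hidden_size_));
```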