Commit 0d3e750

Improve performance of _to_dim_order_copy (#11545)
We had a problem similar to the one motivating BroadcastIndexesRange: we were doing full coordinate recalculation on every iteration of a loop. Thankfully, it is straightforward to use BroadcastIndexesRange (and the new support_noncontiguous_tensors option, which this commit renames to support_noncontiguous_input_tensors) to fix this.
1 parent f476f50 commit 0d3e750
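For orientation, here is a minimal sketch of the pattern this commit adopts in the kernel below. The free-standing wrapper name copy_respecting_dim_order is hypothetical, and the enclosing namespace layout is assumed to follow the kernel file's torch::executor convention; the loop body itself mirrors the new _to_dim_order_copy_impl in the first diff.

// Sketch only: a standalone rendering of the new loop; not part of the commit.
#include <executorch/kernels/portable/cpu/util/broadcast_indexes_range.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {

template <typename SELF_CTYPE, typename OUT_CTYPE>
void copy_respecting_dim_order( // hypothetical name, for illustration
    const executorch::aten::Tensor& self,
    executorch::aten::Tensor& out) {
  auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
  auto out_data = out.mutable_data_ptr<OUT_CTYPE>();
  // support_noncontiguous_input_tensors=true makes the iterator honor each
  // "input" tensor's strides (and therefore its dim_order); passing `self` as
  // a dummy output and the real `out` as a second input extends that stride
  // handling to the output, so no per-element coordinate math is needed.
  for (const auto [unused_index, self_data_index, out_data_index] :
       BroadcastIndexesRange<2, /*support_noncontiguous_input_tensors=*/true>(
           /*dummy output*/ self, self, out)) {
    (void)unused_index;
    out_data[out_data_index] =
        static_cast<OUT_CTYPE>(self_data[self_data_index]);
  }
}

} // namespace executor
} // namespace torch

Relative to the removed code, this drops the hand-rolled coordinate increment and the per-element strides dot product in coordinateToIndexWithDimOrder; the range's iterator advances the linear indexes incrementally instead (see the BroadcastIndexesRange doc comment change below).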

5 files changed: +34 −52 lines
kernels/portable/cpu/op__to_dim_order_copy.cpp

Lines changed: 11 additions & 37 deletions
@@ -9,8 +9,8 @@
 #include <c10/util/irange.h>
 
 #include <executorch/kernels/portable/cpu/scalar_utils.h>
+#include <executorch/kernels/portable/cpu/util/broadcast_util.h>
 #include <executorch/kernels/portable/cpu/util/copy_ops_util.h>
-#include <executorch/runtime/core/exec_aten/util/dim_order_util.h>
 #include <executorch/runtime/kernel/kernel_includes.h>
 
 namespace torch {
@@ -31,47 +31,21 @@ using Optional = std::optional<T>;
 
 namespace {
 
-// TODO(T179241236): Update core/exec_aten/util/tensor_util.h to support dim
-// order other than contiguous.
-int64_t coordinateToIndexWithDimOrder(
-    const Tensor& self,
-    const size_t* cur_indices) {
-  int64_t index = 0;
-  executorch::aten::StridesType strides[kTensorDimensionLimit];
-  SizesArrayRef sizes = self.sizes();
-  DimOrderArrayRef dim_order = self.dim_order();
-
-  dim_order_to_stride_nocheck(
-      sizes.data(), dim_order.data(), sizes.size(), strides);
-  for (const auto i : c10::irange(self.dim())) {
-    index += cur_indices[i] * strides[i];
-  }
-  return index;
-}
-
 template <typename SELF_CTYPE, typename OUT_CTYPE>
 void _to_dim_order_copy_impl(const Tensor& self, Tensor& out) {
   auto self_data = self.mutable_data_ptr<SELF_CTYPE>();
   auto out_data = out.mutable_data_ptr<OUT_CTYPE>();
 
-  size_t coordinate[kTensorDimensionLimit] = {0};
-
-  // Copy data from self to out index by index. Same index in self and out
-  // should have same value, no matter the order of dimensions.
-  for (ssize_t i = 0; i < self.numel(); i++) {
-    // Update the current indices.
-    for (ssize_t j = self.dim() - 1; j >= 0; j--) {
-      if (coordinate[j] + 1 < static_cast<size_t>(self.size(j))) {
-        coordinate[j]++;
-        break;
-      } else {
-        coordinate[j] = 0;
-      }
-    }
-    // Get the corresponding index of self_data and out_data by stride.
-    int64_t self_data_index = coordinateToIndexWithDimOrder(self, coordinate);
-    int64_t out_data_index = coordinateToIndexWithDimOrder(out, coordinate);
-
+  // Here we make a slightly off-label use of
+  // BroadcastIndexesRange. It always assumes it doesn't have to care
+  // about different dim_order between input and output, but we can
+  // just force it to respect strides (and thus dim_order) for its
+  // inputs using support_noncontiguous_input_tensors=true, and then pretend
+  // the output is just another input.
+  for (const auto [unused_index, self_data_index, out_data_index] :
+       BroadcastIndexesRange<2, /*support_noncontiguous_input_tensors=*/true>(
+           /*dummy output*/ self, self, out)) {
+    (void)unused_index;
     out_data[out_data_index] =
         static_cast<OUT_CTYPE>(self_data[self_data_index]);
   }

kernels/portable/cpu/op_glu.cpp

Lines changed: 1 addition & 1 deletion
@@ -110,7 +110,7 @@ Tensor& glu_out_tensor(
         split_input.second_half,
         utils::SupportedTensorDtypes::FLOATHBF16,
         out,
-        utils::internal::SupportNoncontiguousTensors());
+        utils::internal::SupportNoncontiguousInputTensors());
   });
   return out;
 }

kernels/portable/cpu/util/broadcast_indexes_range.h

Lines changed: 16 additions & 9 deletions
@@ -43,7 +43,9 @@ inline bool sizes_match_ignoring_leading_1s(
       std::equal(lhs_begin, lhs_end, rhs_begin);
 }
 
-template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
+template <
+    std::size_t kNumInputs,
+    bool support_noncontiguous_input_tensors = false>
 class BroadcastIndexesIterator {
  public:
   using difference_type = ssize_t;
@@ -57,7 +59,7 @@ class BroadcastIndexesIterator {
   template <typename... Args>
   explicit BroadcastIndexesIterator(const Tensor& output, const Args&... args)
       : output_dim_or_zero_if_no_broadcasting_(
-            !support_noncontiguous_tensors &&
+            !support_noncontiguous_input_tensors &&
             (sizes_match_ignoring_leading_1s(
                  args.sizes(),
                  output.sizes()) &&
@@ -69,7 +71,7 @@ class BroadcastIndexesIterator {
         sizeof...(args) == kNumInputs && (std::is_same_v<Args, Tensor> && ...),
         "BroadcastIndexesIterator constructor requires kNumInputs input tensor"
         "arguments!");
-    if (support_noncontiguous_tensors ||
+    if (support_noncontiguous_input_tensors ||
         output_dim_or_zero_if_no_broadcasting_ != 0) {
       effective_input_broadcast_strides_ = {
           effective_input_broadcast_stride(output, args)...};
@@ -254,16 +256,21 @@
  * linearize_access_indexes(), BroadcastIndexesRange avoids expensive
  * division and modulo operations on each iteration.
  *
- * The support_noncontiguous_tensors argument disables an optimization
- * that causes the iterators not to respect strides in some
- * cases. This optimization is normally safe because ExecuTorch
- * tensors are contiguous.
+ * The support_noncontiguous_input_tensors argument disables an
+ * optimization that causes the iterators not to respect strides in
+ * some cases for input tensors. This optimization is normally safe
+ * because ExecuTorch tensors are contiguous. Non-contiguous output
+ * tensors are currently never supported (but note that this can be
+ * worked around by ignoring the output index and providing the true
+ * output as an extra input).
  */
-template <std::size_t kNumInputs, bool support_noncontiguous_tensors = false>
+template <
+    std::size_t kNumInputs,
+    bool support_noncontiguous_input_tensors = false>
 class BroadcastIndexesRange {
  public:
   using iterator = internal::
-      BroadcastIndexesIterator<kNumInputs, support_noncontiguous_tensors>;
+      BroadcastIndexesIterator<kNumInputs, support_noncontiguous_input_tensors>;
 
   template <typename... Args>
   BroadcastIndexesRange(const Tensor& output, const Args&... args)

kernels/portable/cpu/util/elementwise_util.h

Lines changed: 5 additions & 5 deletions
@@ -56,8 +56,8 @@ namespace internal {
  * strides; normally, this is not strictly necessary because ExecuTorch
  * Tensors are contiguous.
  */
-struct SupportNoncontiguousTensors {
-  explicit SupportNoncontiguousTensors() = default;
+struct SupportNoncontiguousInputTensors {
+  explicit SupportNoncontiguousInputTensors() = default;
 };
 
 template <
@@ -292,7 +292,7 @@ inline void apply_unitensor_elementwise_fn(
     const Tensor& a,
     SupportedTensorDtypes a_dtypes,
     const Tensor& out,
-    SupportNoncontiguousTensors) {
+    SupportNoncontiguousInputTensors) {
   internal::apply_elementwise_fn<
       CTYPE_COMPUTE,
       op_name,
@@ -366,7 +366,7 @@ inline void apply_bitensor_elementwise_fn(
     const Tensor& b,
     SupportedTensorDtypes b_dtypes,
     const Tensor& out,
-    SupportNoncontiguousTensors) {
+    SupportNoncontiguousInputTensors) {
   internal::apply_elementwise_fn<
       CTYPE_COMPUTE,
      op_name,
@@ -467,7 +467,7 @@ inline void apply_tritensor_elementwise_fn(
     const Tensor& c,
     SupportedTensorDtypes c_dtypes,
     const Tensor& out,
-    SupportNoncontiguousTensors) {
+    SupportNoncontiguousInputTensors) {
   internal::apply_elementwise_fn<
       CTYPE_COMPUTE,
       op_name,
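A note on the SupportNoncontiguousInputTensors parameter above: it is an empty tag type whose explicit default constructor forces callers to opt in by name, selecting the stride-respecting overloads (exactly what op_glu.cpp does in this commit). Below is a generic, self-contained sketch of that tag-dispatch pattern; apply_fn and its arguments are invented for illustration and are not ExecuTorch's actual signatures.

// Generic tag-dispatch illustration; names here are hypothetical.
#include <cstdio>

struct SupportNoncontiguousInputTensors {
  // `explicit` blocks accidental opt-in via `{}`; callers must name the tag.
  explicit SupportNoncontiguousInputTensors() = default;
};

// Default overload: free to assume contiguous layouts.
void apply_fn(int numel) {
  std::printf("contiguous fast path, numel=%d\n", numel);
}

// Tag overload: must respect input strides.
void apply_fn(int numel, SupportNoncontiguousInputTensors) {
  std::printf("stride-respecting path, numel=%d\n", numel);
}

int main() {
  apply_fn(8); // picks the fast path
  apply_fn(8, SupportNoncontiguousInputTensors()); // explicit opt-in
  // apply_fn(8, {}); // would not compile: the default constructor is explicit
  return 0;
}

This is why the call sites changed by this commit pass a value of the tag type rather than a boolean flag.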

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 1 addition & 0 deletions
@@ -1324,6 +1324,7 @@ ATEN_OPS = (
         name = "op__to_dim_order_copy",
         deps = [
             ":scalar_utils",
+            "//executorch/kernels/portable/cpu/util:broadcast_util",
             "//executorch/kernels/portable/cpu/util:copy_ops_util",
         ],
     ),
