Skip to content

Commit 2dec883

Browse files
Rename some XLA internal APIs from rank() to num_dimensions() to avoid confusion.
PiperOrigin-RevId: 738831506
1 parent bcd0166 commit 2dec883

16 files changed

+84
-75
lines changed

xla/backends/gpu/codegen/triton/tma_utils_test.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ TEST(Create2DTmaDescriptorTest, ValidInputReturnCorrectDescriptor) {
7070
TmaDescriptor tma_desc,
7171
Create2DTmaDescriptor(global_shape, block_shape, element_type));
7272
EXPECT_EQ(tma_desc.element_size(), 4);
73-
EXPECT_EQ(tma_desc.rank(), 2);
73+
EXPECT_EQ(tma_desc.num_dimensions(), 2);
7474
EXPECT_THAT(tma_desc.global_dims(), ElementsAre(128, 256));
7575
EXPECT_THAT(tma_desc.global_strides(), ElementsAre(128 * 4));
7676
EXPECT_THAT(tma_desc.box_dims(), ElementsAre(32, 64));

xla/mlir/tools/mlir_interpreter/dialects/linalg.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ llvm::SmallVector<InterpreterValue> Map(InterpreterState& state,
149149
isa<TensorType>(op.getInit().getType()) ? init.Clone() : init;
150150

151151
InterpreterScope scope(state);
152-
SmallVector<int64_t> ivs(output.View().Rank());
152+
SmallVector<int64_t> ivs(output.View().num_dimensions());
153153
scope.SetSideChannel(std::make_shared<IterationIndexSideChannel>(ivs));
154154
for (const auto& indices : output.View().Indices()) {
155155
std::copy(indices.begin(), indices.end(), ivs.begin());

xla/mlir/tools/mlir_interpreter/dialects/memref.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,15 @@ InterpreterValue Subview(InterpreterState& state, memref::SubViewOp subview,
129129
return {};
130130
}
131131

132-
if (subview.getResult().getType().getRank() == out_view.Rank()) {
132+
if (subview.getResult().getType().getRank() == out_view.num_dimensions()) {
133133
return out;
134134
}
135135

136136
auto shape = subview.getResult().getType().getShape();
137137
// TODO(jreiffers): Check why subview.getDroppedDims() yields the wrong shape
138138
// here for 1x2x2x3 (-> 1x2x1x3) -> 1x2x3 (claiming 0 is dropped).
139139
int64_t dim = 0;
140-
while (dim < out_view.Rank() && dim < shape.size()) {
140+
while (dim < out_view.num_dimensions() && dim < shape.size()) {
141141
if (shape[dim] != 1 && out_view.sizes[dim] == 1) {
142142
out_view.sizes.erase(out_view.sizes.begin() + dim);
143143
out_view.strides.erase(out_view.strides.begin() + dim);
@@ -147,7 +147,7 @@ InterpreterValue Subview(InterpreterState& state, memref::SubViewOp subview,
147147
++dim;
148148
}
149149
}
150-
while (dim < out_view.Rank()) {
150+
while (dim < out_view.num_dimensions()) {
151151
assert(out_view.sizes.back() == 1 && "expected remaining dims to be 1");
152152
out_view.sizes.pop_back();
153153
out_view.strides.pop_back();

xla/mlir/tools/mlir_interpreter/dialects/mhlo.cc

+17-15
Original file line numberDiff line numberDiff line change
@@ -321,13 +321,13 @@ llvm::SmallVector<InterpreterValue> Gather(
321321
auto& operand = args[0];
322322
auto& start_indices = args[1];
323323
const auto& operand_view = operand.View();
324-
int64_t operand_rank = operand_view.Rank();
324+
int64_t operand_rank = operand_view.num_dimensions();
325325

326326
// Make a fake BufferView for the start indices.
327327
BufferView start_indices_view = start_indices.View();
328-
auto output_rank =
329-
static_cast<int64_t>(start_indices_view.Rank() + offset_dims.size());
330-
if (index_vector_dim < start_indices_view.Rank()) {
328+
auto output_rank = static_cast<int64_t>(start_indices_view.num_dimensions() +
329+
offset_dims.size());
330+
if (index_vector_dim < start_indices_view.num_dimensions()) {
331331
--output_rank;
332332
start_indices_view.sizes[index_vector_dim] = 1;
333333
}
@@ -365,7 +365,7 @@ llvm::SmallVector<InterpreterValue> Gather(
365365
for (auto start_indices_index : start_indices_view.Indices()) {
366366
SmallVector<int64_t> operand_base_indices(operand_rank);
367367
for (auto [i, dim] : llvm::enumerate(start_index_map)) {
368-
if (index_vector_dim < start_indices_view.Rank()) {
368+
if (index_vector_dim < start_indices_view.num_dimensions()) {
369369
start_indices_index[index_vector_dim] = static_cast<int64_t>(i);
370370
}
371371
operand_base_indices[dim] = std::max<int64_t>(
@@ -410,9 +410,9 @@ llvm::SmallVector<InterpreterValue> Scatter(
410410
auto update_window_dims = dims.getUpdateWindowDims();
411411

412412
auto input_view = n_inputs.front().View();
413-
int64_t operand_rank = input_view.Rank();
414-
int64_t updates_rank = n_updates.front().View().Rank();
415-
int64_t indices_rank = scatter_indices.View().Rank();
413+
int64_t operand_rank = input_view.num_dimensions();
414+
int64_t updates_rank = n_updates.front().View().num_dimensions();
415+
int64_t indices_rank = scatter_indices.View().num_dimensions();
416416

417417
llvm::SmallVector<int64_t> batch_dims;
418418
for (int64_t dim = 0; dim < operand_rank; ++dim) {
@@ -689,23 +689,23 @@ InterpreterValue DotGeneralImpl(InterpreterValue& lhs, InterpreterValue& rhs,
689689
for (int64_t lhs_dim : lhs_batch) {
690690
dimensions.push_back(lhsv.sizes[lhs_dim]);
691691
}
692-
for (int64_t i = 0; i < lhsv.Rank(); i++) {
692+
for (int64_t i = 0; i < lhsv.num_dimensions(); i++) {
693693
if (!llvm::is_contained(lhs_contracting, i) &&
694694
!llvm::is_contained(lhs_batch, i)) {
695695
dimensions.push_back(lhsv.sizes[i]);
696696
lhs_non_batch.push_back(i);
697697
}
698698
}
699-
for (int64_t i = 0; i < rhs.View().Rank(); i++) {
699+
for (int64_t i = 0; i < rhs.View().num_dimensions(); i++) {
700700
if (!llvm::is_contained(rhs_contracting, i) &&
701701
!llvm::is_contained(rhs_batch, i)) {
702702
dimensions.push_back(rhsv.sizes[i]);
703703
rhs_non_batch.push_back(i);
704704
}
705705
}
706706

707-
SmallVector<int64_t> lhs_index(lhsv.Rank());
708-
SmallVector<int64_t> rhs_index(rhsv.Rank());
707+
SmallVector<int64_t> lhs_index(lhsv.num_dimensions());
708+
SmallVector<int64_t> rhs_index(rhsv.num_dimensions());
709709
SmallVector<int64_t> output_index(dimensions.size());
710710
auto output = lhs.TypedAlike(dimensions);
711711

@@ -778,7 +778,7 @@ InterpreterValue Dot(InterpreterState& state, mhlo::DotOp op,
778778
auto ty = cast<ShapedType>(op->getResultTypes()[0]);
779779
auto result = lhs.TypedAlike(ty.getShape());
780780

781-
if (lhs.View().Rank() == 1 && rhs.View().Rank() == 1) {
781+
if (lhs.View().num_dimensions() == 1 && rhs.View().num_dimensions() == 1) {
782782
DispatchScalarType(ty, [&](auto dummy) {
783783
using T = decltype(dummy);
784784
using TT = TensorOrMemref<T>;
@@ -791,7 +791,8 @@ InterpreterValue Dot(InterpreterState& state, mhlo::DotOp op,
791791
}
792792
std::get<TT>(result.storage).at({}) += product;
793793
});
794-
} else if (lhs.View().Rank() == 2 && rhs.View().Rank() == 1) {
794+
} else if (lhs.View().num_dimensions() == 2 &&
795+
rhs.View().num_dimensions() == 1) {
795796
DispatchScalarType(ty, [&](auto dummy) {
796797
using TT = TensorOrMemref<decltype(dummy)>;
797798
auto lhs_tensor = std::get<TT>(lhs.storage);
@@ -803,7 +804,8 @@ InterpreterValue Dot(InterpreterState& state, mhlo::DotOp op,
803804
}
804805
}
805806
});
806-
} else if (lhs.View().Rank() == 2 && rhs.View().Rank() == 2) {
807+
} else if (lhs.View().num_dimensions() == 2 &&
808+
rhs.View().num_dimensions() == 2) {
807809
DispatchScalarType(ty, [&](auto dummy) {
808810
using TT = TensorOrMemref<decltype(dummy)>;
809811
auto lhs_tensor = std::get<TT>(lhs.storage);

xla/mlir/tools/mlir_interpreter/dialects/tensor.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ llvm::SmallVector<InterpreterValue> ExtractSlice(
126126
int64_t dim = 0;
127127
const auto& result_sizes = extract.getResultType().getShape();
128128
const auto& static_sizes = extract.getStaticSizes();
129-
while (dim < out_view.Rank()) {
129+
while (dim < out_view.num_dimensions()) {
130130
if (static_sizes[num_dropped + dim] == 1 &&
131131
(dim >= result_sizes.size() || result_sizes[dim] != 1)) {
132132
out_view.sizes.erase(out_view.sizes.begin() + dim);

xla/mlir/tools/mlir_interpreter/dialects/util.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ InterpreterValue TransposeImpl(const InterpreterValue& in,
121121

122122
int64_t DimImpl(const InterpreterValue& in, int64_t index,
123123
InterpreterState& state) {
124-
if (index < 0 || index >= in.View().Rank()) {
124+
if (index < 0 || index >= in.View().num_dimensions()) {
125125
state.AddFailure("dimension index out of bounds");
126126
return 0;
127127
}

xla/mlir/tools/mlir_interpreter/dialects/vector.cc

+7-6
Original file line numberDiff line numberDiff line change
@@ -299,7 +299,7 @@ InterpreterValue Extract(InterpreterState& state, vector::ExtractOp extract,
299299
for (int64_t offset : extract.getStaticPosition()) {
300300
state.CheckSuccess(result_view.Slice(0, offset), "index out of bounds");
301301
}
302-
return result_view.Rank() == 0 ? result.ExtractElement({}) : result;
302+
return result_view.num_dimensions() == 0 ? result.ExtractElement({}) : result;
303303
}
304304

305305
InterpreterValue ExtractElement(InterpreterState& state,
@@ -635,7 +635,7 @@ InterpreterValue Shuffle(InterpreterState& state, vector::ShuffleOp shuffle,
635635
result_view.is_vector = true;
636636

637637
auto mask = shuffle.getMask();
638-
bool is_zero_dim = v0.View().Rank() == 0;
638+
bool is_zero_dim = v0.View().num_dimensions() == 0;
639639
int64_t size0 = is_zero_dim ? 1 : v0.View().sizes[0];
640640
for (auto [dst_index, src_index] : llvm::enumerate(mask)) {
641641
auto src = src_index < size0 ? v0 : v1;
@@ -698,9 +698,9 @@ std::optional<InterpreterValue> ExtractMemorySlice(
698698
auto mem_slice = memory;
699699
auto& mem_slice_view = mem_slice.View();
700700
auto& vector_view = vector.View();
701-
for (int64_t i = 0; i < mem_slice_view.Rank(); ++i) {
701+
for (int64_t i = 0; i < mem_slice_view.num_dimensions(); ++i) {
702702
bool found = false;
703-
for (int64_t j = 0; !found && j < vector_view.Rank(); ++j) {
703+
for (int64_t j = 0; !found && j < vector_view.num_dimensions(); ++j) {
704704
if (map.getResult(j).isFunctionOfDim(i)) {
705705
int64_t size = mem_slice_view.sizes[i] - offsets[i];
706706
bool is_in_bounds = size >= vector_view.sizes[j];
@@ -801,7 +801,8 @@ llvm::SmallVector<InterpreterValue> TransferWrite(
801801
}
802802

803803
const auto& src_view = src.View();
804-
assert(transfer.getPermutationMap().getNumResults() == src_view.Rank() &&
804+
assert(transfer.getPermutationMap().getNumResults() ==
805+
src_view.num_dimensions() &&
805806
"expected matching number of results");
806807

807808
dst = transfer.getSource().getType().isa<TensorType>() ? dst.Clone() : dst;
@@ -832,7 +833,7 @@ InterpreterValue Transpose(InterpreterState&, vector::TransposeOp transpose,
832833

833834
InterpreterValue TypeCast(InterpreterState&, vector::TypeCastOp,
834835
InterpreterValue vector) {
835-
vector.View().num_vector_dims = vector.View().Rank();
836+
vector.View().num_vector_dims = vector.View().num_dimensions();
836837
return vector;
837838
}
838839

xla/mlir/tools/mlir_interpreter/framework/interpreter_value.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct InterpreterValuePrinter {
8383
} else {
8484
os << TypeStr::Get(T{}) << ">: ";
8585
}
86-
SmallVector<int64_t> indices(t.view.Rank() +
86+
SmallVector<int64_t> indices(t.view.num_dimensions() +
8787
t.view.num_vector_dims.value_or(0));
8888
std::function<void(int64_t)> print;
8989
print = [&](int64_t dim) {

xla/mlir/tools/mlir_interpreter/framework/tensor_or_memref.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -77,22 +77,22 @@ SmallVector<int64_t> BufferView::GetStridesForLayout(ArrayRef<int64_t> sizes,
7777
}
7878

7979
LogicalResult BufferView::Slice(int64_t dim_index, int64_t dim_offset) {
80-
llvm::SmallVector<int64_t> offsets(Rank(), 0);
80+
llvm::SmallVector<int64_t> offsets(num_dimensions(), 0);
8181
offsets[dim_index] = dim_offset;
8282
if (auto new_offset = GetPhysicalIndex(offsets)) {
8383
offset = *new_offset;
8484
} else {
8585
return failure();
8686
}
87-
if (dim_index >= Rank()) --*num_vector_dims;
87+
if (dim_index >= num_dimensions()) --*num_vector_dims;
8888
strides.erase(strides.begin() + dim_index);
8989
sizes.erase(sizes.begin() + dim_index);
9090
return success();
9191
}
9292

9393
LogicalResult BufferView::Slice(int64_t dim_index, int64_t dim_offset,
9494
int64_t dim_size, int64_t dim_stride) {
95-
llvm::SmallVector<int64_t> offsets(Rank(), 0);
95+
llvm::SmallVector<int64_t> offsets(num_dimensions(), 0);
9696
offsets[dim_index] = dim_offset;
9797
if (dim_size == 0) {
9898
offset = 0;

xla/mlir/tools/mlir_interpreter/framework/tensor_or_memref.h

+8-4
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,9 @@ struct BufferView {
7373
std::optional<int64_t> num_vector_dims = std::nullopt;
7474
bool is_vector = false;
7575

76-
int64_t Rank() const { return sizes.size() - num_vector_dims.value_or(0); }
76+
int64_t num_dimensions() const {
77+
return sizes.size() - num_vector_dims.value_or(0);
78+
}
7779

7880
// Removes the dimension from the view. If you need to keep it, use the
7981
// overload below with dim_size = 1.
@@ -154,7 +156,7 @@ struct BufferView {
154156
return {
155157
view_,
156158
llvm::SmallVector<int64_t>(
157-
view_->Rank() +
159+
view_->num_dimensions() +
158160
(include_vector_dims_ ? view_->num_vector_dims.value_or(0) : 0)),
159161
include_vector_dims_};
160162
}
@@ -314,8 +316,10 @@ struct TensorOrMemref {
314316
TensorOrMemref VectorAt(ArrayRef<int64_t> indices) const {
315317
auto offset = view.GetPhysicalIndex(indices);
316318
BufferView subview;
317-
subview.strides = {view.strides.begin() + view.Rank(), view.strides.end()};
318-
subview.sizes = {view.sizes.begin() + view.Rank(), view.sizes.end()};
319+
subview.strides = {view.strides.begin() + view.num_dimensions(),
320+
view.strides.end()};
321+
subview.sizes = {view.sizes.begin() + view.num_dimensions(),
322+
view.sizes.end()};
319323
if (offset) {
320324
subview.offset = *offset;
321325
} else {

xla/mlir/tools/mlir_interpreter/framework/tests/tensor_or_memref_test.cc

+3-3
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ std::optional<int64_t> GetCollapsedStrideNaive(llvm::ArrayRef<int64_t> dims,
5656
// Find all physical indices for the dimensions.
5757
llvm::SmallBitVector v(view.GetNumElements());
5858
for (const auto& indices : f.Indices()) {
59-
SmallVector<int64_t> view_indices(view.Rank());
59+
SmallVector<int64_t> view_indices(view.num_dimensions());
6060
for (auto [dim, index] : llvm::zip(dims, indices)) {
6161
view_indices[dim] = index;
6262
}
@@ -83,9 +83,9 @@ TEST(TensorOrMemrefTest, CollapsedStride) {
8383
.strides = BufferView::GetDefaultStrides({1, 2, 3, 1, 5})};
8484

8585
auto check_all = [&]() {
86-
for (int64_t i = 0; i < (1 << view.Rank()); ++i) {
86+
for (int64_t i = 0; i < (1 << view.num_dimensions()); ++i) {
8787
SmallVector<int64_t> dims;
88-
for (int64_t dim = 0; dim < view.Rank(); ++dim) {
88+
for (int64_t dim = 0; dim < view.num_dimensions(); ++dim) {
8989
if (i & (1 << dim)) dims.push_back(dim);
9090
}
9191

0 commit comments

Comments (0)