@@ -321,13 +321,13 @@ llvm::SmallVector<InterpreterValue> Gather(
321
321
auto & operand = args[0 ];
322
322
auto & start_indices = args[1 ];
323
323
const auto & operand_view = operand.View ();
324
- int64_t operand_rank = operand_view.Rank ();
324
+ int64_t operand_rank = operand_view.num_dimensions ();
325
325
326
326
// Make a fake BufferView for the start indices.
327
327
BufferView start_indices_view = start_indices.View ();
328
- auto output_rank =
329
- static_cast < int64_t >(start_indices_view. Rank () + offset_dims.size ());
330
- if (index_vector_dim < start_indices_view.Rank ()) {
328
+ auto output_rank = static_cast < int64_t >(start_indices_view. num_dimensions () +
329
+ offset_dims.size ());
330
+ if (index_vector_dim < start_indices_view.num_dimensions ()) {
331
331
--output_rank;
332
332
start_indices_view.sizes [index_vector_dim] = 1 ;
333
333
}
@@ -365,7 +365,7 @@ llvm::SmallVector<InterpreterValue> Gather(
365
365
for (auto start_indices_index : start_indices_view.Indices ()) {
366
366
SmallVector<int64_t > operand_base_indices (operand_rank);
367
367
for (auto [i, dim] : llvm::enumerate (start_index_map)) {
368
- if (index_vector_dim < start_indices_view.Rank ()) {
368
+ if (index_vector_dim < start_indices_view.num_dimensions ()) {
369
369
start_indices_index[index_vector_dim] = static_cast <int64_t >(i);
370
370
}
371
371
operand_base_indices[dim] = std::max<int64_t >(
@@ -410,9 +410,9 @@ llvm::SmallVector<InterpreterValue> Scatter(
410
410
auto update_window_dims = dims.getUpdateWindowDims ();
411
411
412
412
auto input_view = n_inputs.front ().View ();
413
- int64_t operand_rank = input_view.Rank ();
414
- int64_t updates_rank = n_updates.front ().View ().Rank ();
415
- int64_t indices_rank = scatter_indices.View ().Rank ();
413
+ int64_t operand_rank = input_view.num_dimensions ();
414
+ int64_t updates_rank = n_updates.front ().View ().num_dimensions ();
415
+ int64_t indices_rank = scatter_indices.View ().num_dimensions ();
416
416
417
417
llvm::SmallVector<int64_t > batch_dims;
418
418
for (int64_t dim = 0 ; dim < operand_rank; ++dim) {
@@ -689,23 +689,23 @@ InterpreterValue DotGeneralImpl(InterpreterValue& lhs, InterpreterValue& rhs,
689
689
for (int64_t lhs_dim : lhs_batch) {
690
690
dimensions.push_back (lhsv.sizes [lhs_dim]);
691
691
}
692
- for (int64_t i = 0 ; i < lhsv.Rank (); i++) {
692
+ for (int64_t i = 0 ; i < lhsv.num_dimensions (); i++) {
693
693
if (!llvm::is_contained (lhs_contracting, i) &&
694
694
!llvm::is_contained (lhs_batch, i)) {
695
695
dimensions.push_back (lhsv.sizes [i]);
696
696
lhs_non_batch.push_back (i);
697
697
}
698
698
}
699
- for (int64_t i = 0 ; i < rhs.View ().Rank (); i++) {
699
+ for (int64_t i = 0 ; i < rhs.View ().num_dimensions (); i++) {
700
700
if (!llvm::is_contained (rhs_contracting, i) &&
701
701
!llvm::is_contained (rhs_batch, i)) {
702
702
dimensions.push_back (rhsv.sizes [i]);
703
703
rhs_non_batch.push_back (i);
704
704
}
705
705
}
706
706
707
- SmallVector<int64_t > lhs_index (lhsv.Rank ());
708
- SmallVector<int64_t > rhs_index (rhsv.Rank ());
707
+ SmallVector<int64_t > lhs_index (lhsv.num_dimensions ());
708
+ SmallVector<int64_t > rhs_index (rhsv.num_dimensions ());
709
709
SmallVector<int64_t > output_index (dimensions.size ());
710
710
auto output = lhs.TypedAlike (dimensions);
711
711
@@ -778,7 +778,7 @@ InterpreterValue Dot(InterpreterState& state, mhlo::DotOp op,
778
778
auto ty = cast<ShapedType>(op->getResultTypes ()[0 ]);
779
779
auto result = lhs.TypedAlike (ty.getShape ());
780
780
781
- if (lhs.View ().Rank () == 1 && rhs.View ().Rank () == 1 ) {
781
+ if (lhs.View ().num_dimensions () == 1 && rhs.View ().num_dimensions () == 1 ) {
782
782
DispatchScalarType (ty, [&](auto dummy) {
783
783
using T = decltype (dummy);
784
784
using TT = TensorOrMemref<T>;
@@ -791,7 +791,8 @@ InterpreterValue Dot(InterpreterState& state, mhlo::DotOp op,
791
791
}
792
792
std::get<TT>(result.storage ).at ({}) += product;
793
793
});
794
- } else if (lhs.View ().Rank () == 2 && rhs.View ().Rank () == 1 ) {
794
+ } else if (lhs.View ().num_dimensions () == 2 &&
795
+ rhs.View ().num_dimensions () == 1 ) {
795
796
DispatchScalarType (ty, [&](auto dummy) {
796
797
using TT = TensorOrMemref<decltype (dummy)>;
797
798
auto lhs_tensor = std::get<TT>(lhs.storage );
@@ -803,7 +804,8 @@ InterpreterValue Dot(InterpreterState& state, mhlo::DotOp op,
803
804
}
804
805
}
805
806
});
806
- } else if (lhs.View ().Rank () == 2 && rhs.View ().Rank () == 2 ) {
807
+ } else if (lhs.View ().num_dimensions () == 2 &&
808
+ rhs.View ().num_dimensions () == 2 ) {
807
809
DispatchScalarType (ty, [&](auto dummy) {
808
810
using TT = TensorOrMemref<decltype (dummy)>;
809
811
auto lhs_tensor = std::get<TT>(lhs.storage );
0 commit comments