Skip to content

Commit 37ffbbb

Browse files
author
Peiming Liu
authored
[mlir][tensor][sparse] don't drop encoding when infer result type (llvm#91817)
A general question is: is it possible to support hooks here to infer the encoding? E.g., when the extracted tensor slice is rank-reduced, the encoding need to be updated accordingly as well.
1 parent 6140b5b commit 37ffbbb

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -2020,7 +2020,8 @@ RankedTensorType ExtractSliceOp::inferResultType(
20202020
assert(static_cast<int64_t>(staticSizes.size()) ==
20212021
sourceTensorType.getRank() &&
20222022
"unexpected staticSizes not equal to rank of source");
2023-
return RankedTensorType::get(staticSizes, sourceTensorType.getElementType());
2023+
return RankedTensorType::get(staticSizes, sourceTensorType.getElementType(),
2024+
sourceTensorType.getEncoding());
20242025
}
20252026

20262027
RankedTensorType ExtractSliceOp::inferResultType(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// RUN: mlir-opt %s -split-input-file -canonicalize="test-convergence" | FileCheck %s
2+
3+
#BCOO = #sparse_tensor.encoding<{
4+
map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton)
5+
}>
6+
7+
// CHECK-DAG: #[[$BCOO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : loose_compressed(nonunique), d2 : singleton) }>
8+
// CHECK-LABEL: func @sparse_slice_canonicalize
9+
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x?x?xf32, #[[$BCOO]]>
10+
// CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[ARG0]][0, %{{[a-zA-Z0-9_]+}}, 1]
11+
// CHECK-SAME: [4, 1, %{{[a-zA-Z0-9_]+}}] [1, 1, 1]
12+
// CHECK-SAME: : tensor<?x?x?xf32, #[[$BCOO]]> to tensor<4x1x?xf32, #[[$BCOO]]>
13+
// CHECK: %[[RESULT:.+]] = tensor.cast %[[SLICE]]
14+
// CHECK: return %[[RESULT]]
15+
func.func @sparse_slice_canonicalize(%arg0 : tensor<?x?x?xf32, #BCOO>, %arg1 : index,
16+
%arg2 : index) -> tensor<?x?x?xf32, #BCOO>
17+
{
18+
%c0 = arith.constant 0 : index
19+
%c1 = arith.constant 1 : index
20+
%c4 = arith.constant 4 : index
21+
%0 = tensor.extract_slice %arg0[%c0, %arg1, %c1] [%c4, %c1, %arg2] [%c1, %c1, %c1] : tensor<?x?x?xf32, #BCOO> to tensor<?x?x?xf32, #BCOO>
22+
return %0 : tensor<?x?x?xf32, #BCOO>
23+
}

0 commit comments

Comments
 (0)