[mlir][bufferization] Enable moving dependent values in eliminate-empty-tensors #169718
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir-bufferization

Author: Quinn Dawkins (qedawkins)

Changes

Currently, empty tensor elimination, which constructs a SubsetExtractionOp to match a SubsetInsertionOp at the end of a DPS chain, fails if any operand required by the insertion op does not dominate the insertion point chosen for the extraction op. This change improves the transformation by attempting to move all pure producers of the required operands to the insertion point of the extraction op. In the process, this improves a number of tests for empty tensor elimination.

Full diff: https://github.com/llvm/llvm-project/pull/169718.diff

4 Files Affected:
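The failure mode this addresses can be sketched with IR along the lines of the new `@move_dependent_arith_op` test added in the diff below (illustrative; the function and value names here are for exposition):

```mlir
func.func @example(%arg0: tensor<10xf32>, %arg1: index, %f: f32) -> tensor<10xf32> {
  %0 = tensor.empty() : tensor<5xf32>
  %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
  // %offset is defined after the DPS chain begins, so it does not dominate
  // the point where the matching tensor.extract_slice would be created.
  %c5 = arith.constant 5 : index
  %offset = arith.addi %arg1, %c5 : index
  %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
    : tensor<5xf32> into tensor<10xf32>
  return %2 : tensor<10xf32>
}
```

Previously elimination bailed out here; with this change, the pure producers (`arith.constant` and `arith.addi`) are moved above the chain, so the `tensor.empty` can be replaced with a `tensor.extract_slice` of `%arg0` at `[%offset][5][1]`.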
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 2ed96afbace81..6a0c94b06c6b2 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -85,11 +85,15 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
/// only for movement of definitions within the same basic block. Note that this
/// is an all-or-nothing approach. Either definitions of all values are moved
/// before insertion point, or none of them are.
+/// If `ignoreSideEffect` is set, this will allow movement of all dependent
+/// producers regardless of whether they are side effecting.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
Operation *insertionPoint,
- DominanceInfo &dominance);
+ DominanceInfo &dominance,
+ bool ignoreSideEffects = true);
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
- Operation *insertionPoint);
+ Operation *insertionPoint,
+ bool ignoreSideEffects = true);
/// Run a set of structural simplifications over the given regions. This
/// includes transformations like unreachable block elimination, dead argument
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 1784964cf9b95..0843b4398b24f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
+#include "mlir/Transforms/RegionUtils.h"
namespace mlir {
namespace bufferization {
@@ -105,8 +106,15 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
// this replacement.
Operation *insertionPoint =
findValidInsertionPoint(emptyTensorOp, user, neededValues);
- if (!insertionPoint)
- return {};
+ if (!insertionPoint) {
+ // If no already suitable insertion point was found, attempt to move all
+ // needed values before the user.
+ if (failed(moveValueDefinitions(rewriter, neededValues, user,
+ /*ignoreSideEffects=*/false))) {
+ return {};
+ }
+ insertionPoint = user;
+ }
rewriter.setInsertionPoint(insertionPoint);
Value replacement =
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 31ae1d1895b81..390fc76cc6533 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1145,7 +1145,8 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
ValueRange values,
Operation *insertionPoint,
- DominanceInfo &dominance) {
+ DominanceInfo &dominance,
+ bool ignoreSideEffects) {
// Remove the values that already dominate the insertion point.
SmallVector<Value> prunedValues;
for (auto value : values) {
@@ -1178,8 +1179,14 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Since current support is to only move within a same basic block,
// the slices dont need to look past block arguments.
options.omitBlockArguments = true;
+ bool dependsOnSideEffectingOp = false;
options.filter = [&](Operation *sliceBoundaryOp) {
- return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+ bool mustMove =
+ !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+ if (mustMove && !isPure(sliceBoundaryOp)) {
+ dependsOnSideEffectingOp = true;
+ }
+ return mustMove;
};
llvm::SetVector<Operation *> slice;
for (auto value : prunedValues) {
@@ -1188,6 +1195,10 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
(void)result;
}
+ // Check if any operation in the slice is side-effecting.
+ if (!ignoreSideEffects && dependsOnSideEffectingOp)
+ return failure();
+
// If the slice contains `insertionPoint` cannot move the dependencies.
if (slice.contains(insertionPoint)) {
return rewriter.notifyMatchFailure(
@@ -1206,7 +1217,9 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
ValueRange values,
- Operation *insertionPoint) {
+ Operation *insertionPoint,
+ bool ignoreSideEffects) {
DominanceInfo dominance(insertionPoint);
- return moveValueDefinitions(rewriter, values, insertionPoint, dominance);
+ return moveValueDefinitions(rewriter, values, insertionPoint, dominance,
+ ignoreSideEffects);
}
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 8249d59b2374e..3929f5be3b4ef 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -368,21 +368,18 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
// -----
-// `EmptyTensorElimination` fails to find a valid insertion
-// point for the new injected `SubsetExtraction`.
-// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
-func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
+// CHECK-LABEL: func.func @eliminate_all_empty_tensors
+func.func @eliminate_all_empty_tensors() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
- // CHECK: memref.alloc
- // CHECK: memref.alloc
- // CHECK: memref.alloc
+ // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+ // CHECK-NOT: memref.alloc
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
- // CHECK: memref.copy
+ // CHECK-NOT: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
@@ -392,20 +389,19 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
// -----
-// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor
-func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
+// CHECK-LABEL: func.func @eliminate_concatenated_empty_tensors
+func.func @eliminate_concatenated_empty_tensors() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
- // CHECK: memref.alloc
// CHECK-NOT: memref.alloc
- %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+ %concatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
- // CHECK: memref.copy
- %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+ // CHECK-NOT: memref.copy
+ %inserted_slice_1 = tensor.insert_slice %res_1 into %concatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
@@ -420,20 +416,22 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
// CHECK-ELIM-LABEL: func.func @multi_use_of_the_same_tensor_empty
// CHECK-LABEL: func.func @multi_use_of_the_same_tensor_empty
+// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+// CHECK-NOT: memref.alloc
+// CHECK-NOT: memref.copy
+// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 0]
+// CHECK-ELIM: linalg.fill
+// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 64]
+// CHECK-ELIM: linalg.fill
func.func @multi_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
- // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
- // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
- // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
- // CHECK: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
- // CHECK-NOT: memref.copy
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_2 : tensor<5x6x128xf32>
@@ -476,3 +474,66 @@ func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_1 : tensor<5x6x128xf32>
}
+
+// -----
+
+// Test that dependent pure operations are moved before the
+// insertion point to enable empty tensor elimination.
+
+// CHECK-LABEL: func.func @move_dependent_arith_op(
+// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>
+// CHECK-SAME: %[[ARG1:.*]]: index
+// CHECK-NOT: memref.alloc
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
+// CHECK: %[[SV:.*]] = memref.subview %[[ARG0]][%[[OFFSET]]] [5] [1]
+// CHECK: linalg.fill {{.*}} outs(%[[SV]]
+// CHECK: return %[[ARG0]]
+// CHECK-ELIM-LABEL: func.func @move_dependent_arith_op(
+// CHECK-ELIM-SAME: %[[ARG0:.*]]: tensor<10xf32>
+// CHECK-ELIM-SAME: %[[ARG1:.*]]: index
+// CHECK-ELIM: %[[C5:.*]] = arith.constant 5 : index
+// CHECK-ELIM: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
+// CHECK-ELIM: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[OFFSET]]] [5] [1]
+// CHECK-ELIM: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]]
+// CHECK-ELIM: tensor.insert_slice %[[FILL]] into %[[ARG0]][%[[OFFSET]]]
+func.func @move_dependent_arith_op(
+ %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+ %arg1: index, %f: f32) -> tensor<10xf32>
+{
+ %0 = tensor.empty() : tensor<5xf32>
+ %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ %c5 = arith.constant 5 : index
+ %offset = arith.addi %arg1, %c5 : index
+ %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
+ : tensor<5xf32> into tensor<10xf32>
+ return %2 : tensor<10xf32>
+}
+
+// -----
+
+// Test that side-effecting operations are not moved, preventing empty
+// tensor elimination.
+
+// CHECK-LABEL: func.func @side_effecting_op_blocks_movement(
+// CHECK: memref.alloc
+// CHECK: linalg.fill
+// CHECK: memref.load
+// CHECK: memref.subview
+// CHECK: memref.copy
+// CHECK-ELIM-LABEL: func.func @side_effecting_op_blocks_movement(
+// CHECK-ELIM: tensor.empty
+// CHECK-ELIM: linalg.fill
+// CHECK-ELIM: memref.load
+// CHECK-ELIM: tensor.insert_slice
+func.func @side_effecting_op_blocks_movement(
+ %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+ %mem: memref<index>, %f: f32) -> tensor<10xf32>
+{
+ %0 = tensor.empty() : tensor<5xf32>
+ %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+ %offset = memref.load %mem[] : memref<index>
+ %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
+ : tensor<5xf32> into tensor<10xf32>
+ return %2 : tensor<10xf32>
+}
MaheshRavishankar
left a comment
Did you have a use case for `ignoreSideEffects = true`? It is probably better to have `ignoreSideEffects = false` by default (or just not change the API, and make sure that we always account for side-effecting operations).

qedawkins replied:

I did not; it was just an attempt to keep the API NFC. I can happily drop the flag and default it to true then.
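For context on why the default matters, the side-effecting case can be sketched along the lines of the `@side_effecting_op_blocks_movement` test in the diff (illustrative names):

```mlir
func.func @example(%arg0: tensor<10xf32>, %mem: memref<index>, %f: f32) -> tensor<10xf32> {
  %0 = tensor.empty() : tensor<5xf32>
  %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
  // memref.load is not pure: hoisting it above the chain could reorder it
  // relative to other memory effects, so the move (and thus elimination)
  // is declined when side effects are accounted for.
  %offset = memref.load %mem[] : memref<index>
  %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
    : tensor<5xf32> into tensor<10xf32>
  return %2 : tensor<10xf32>
}
```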