Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mlir/include/mlir/Transforms/RegionUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
/// Move definitions of `values` before an insertion point. Current support is
/// only for movement of definitions within the same basic block. Note that this
/// is an all-or-nothing approach. Either definitions of all values are moved
/// before insertion point, or none of them are.
/// before insertion point, or none of them are. Any side-effecting operation
/// in the producer chain pessimistically blocks movement.
LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
Operation *insertionPoint,
DominanceInfo &dominance);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SubsetOpInterface.h"
#include "mlir/Transforms/RegionUtils.h"

namespace mlir {
namespace bufferization {
Expand Down Expand Up @@ -105,8 +106,13 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
// this replacement.
Operation *insertionPoint =
findValidInsertionPoint(emptyTensorOp, user, neededValues);
if (!insertionPoint)
return {};
if (!insertionPoint) {
    // If no suitable insertion point already exists, attempt to move the
    // definitions of all needed values before the user.
if (failed(moveValueDefinitions(rewriter, neededValues, user)))
return {};
insertionPoint = user;
}

rewriter.setInsertionPoint(insertionPoint);
Value replacement =
Expand Down
17 changes: 12 additions & 5 deletions mlir/lib/Transforms/Utils/RegionUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1149,9 +1149,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Remove the values that already dominate the insertion point.
SmallVector<Value> prunedValues;
for (auto value : values) {
if (dominance.properlyDominates(value, insertionPoint)) {
if (dominance.properlyDominates(value, insertionPoint))
continue;
}
// Block arguments are not supported.
if (isa<BlockArgument>(value)) {
return rewriter.notifyMatchFailure(
Expand All @@ -1178,8 +1177,13 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
  // Since the current support only moves definitions within the same basic
  // block, the slices don't need to look past block arguments.
options.omitBlockArguments = true;
bool dependsOnSideEffectingOp = false;
options.filter = [&](Operation *sliceBoundaryOp) {
return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
bool mustMove =
!dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
if (mustMove && !isPure(sliceBoundaryOp))
dependsOnSideEffectingOp = true;
return mustMove;
};
llvm::SetVector<Operation *> slice;
for (auto value : prunedValues) {
Expand All @@ -1188,6 +1192,10 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
(void)result;
}

// Check if any operation in the slice is side-effecting.
if (dependsOnSideEffectingOp)
return failure();

  // If the slice contains `insertionPoint`, the dependencies cannot be moved.
if (slice.contains(insertionPoint)) {
return rewriter.notifyMatchFailure(
Expand All @@ -1198,9 +1206,8 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
// Sort operations topologically before moving.
mlir::topologicalSort(slice);

for (Operation *op : slice) {
for (Operation *op : slice)
rewriter.moveOpBefore(op, insertionPoint);
}
return success();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -368,21 +368,18 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32

// -----

// `EmptyTensorElimination` fails to find a valid insertion
// point for the new injected `SubsetExtraction`.
// CHECK-LABEL: func.func @fail_to_eliminate_any_empty_tensors
func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
// CHECK-LABEL: func.func @eliminate_all_empty_tensors
func.func @eliminate_all_empty_tensors() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
// CHECK: memref.alloc
// CHECK: memref.alloc
// CHECK: memref.alloc
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
// CHECK-NOT: memref.alloc
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
// CHECK: memref.copy
// CHECK-NOT: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
Expand All @@ -392,20 +389,19 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {

// -----

// CHECK-LABEL: func.func @succeed_to_eliminate_one_empty_tensor
func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
// CHECK-LABEL: func.func @eliminate_concatenated_empty_tensors
func.func @eliminate_concatenated_empty_tensors() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
// CHECK: memref.alloc
// CHECK-NOT: memref.alloc
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%concatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%empty_2 = tensor.empty() : tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
// CHECK: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
// CHECK-NOT: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %concatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
Expand All @@ -420,20 +416,22 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {

// CHECK-ELIM-LABEL: func.func @multi_use_of_the_same_tensor_empty
// CHECK-LABEL: func.func @multi_use_of_the_same_tensor_empty
// CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
// CHECK-NOT: memref.alloc
// CHECK-NOT: memref.copy
// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 0]
// CHECK-ELIM: linalg.fill
// CHECK-ELIM: tensor.extract_slice {{.*}}[0, 0, 64]
// CHECK-ELIM: linalg.fill
func.func @multi_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
%cst_1 = arith.constant 1.0 : f32
%cst_2 = arith.constant 2.0 : f32
%cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
%empty_1 = tensor.empty() : tensor<5x6x64xf32>
// CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
// CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
// CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
%res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
%res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
// CHECK: memref.copy
%inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
// CHECK-NOT: memref.copy
%inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_2 : tensor<5x6x128xf32>
Expand Down Expand Up @@ -476,3 +474,66 @@ func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x
: tensor<5x6x64xf32> into tensor<5x6x128xf32>
return %inserted_slice_1 : tensor<5x6x128xf32>
}

// -----

// Test that dependent pure operations are moved before the
// insertion point to enable empty tensor elimination.

// CHECK-LABEL: func.func @move_dependent_arith_op(
// CHECK-SAME: %[[ARG0:.*]]: memref<10xf32>
// CHECK-SAME: %[[ARG1:.*]]: index
// CHECK-NOT: memref.alloc
// CHECK: %[[C5:.*]] = arith.constant 5 : index
// CHECK: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
// CHECK: %[[SV:.*]] = memref.subview %[[ARG0]][%[[OFFSET]]] [5] [1]
// CHECK: linalg.fill {{.*}} outs(%[[SV]]
// CHECK: return %[[ARG0]]
// CHECK-ELIM-LABEL: func.func @move_dependent_arith_op(
// CHECK-ELIM-SAME: %[[ARG0:.*]]: tensor<10xf32>
// CHECK-ELIM-SAME: %[[ARG1:.*]]: index
// CHECK-ELIM: %[[C5:.*]] = arith.constant 5 : index
// CHECK-ELIM: %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
// CHECK-ELIM: %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[OFFSET]]] [5] [1]
// CHECK-ELIM: %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]]
// CHECK-ELIM: tensor.insert_slice %[[FILL]] into %[[ARG0]][%[[OFFSET]]]
// Every producer of `%offset` (arith.constant, arith.addi) is pure, so the
// pass is expected to move them before the insert_slice user and rewrite the
// tensor.empty as an extract_slice of %arg0 (see CHECK/CHECK-ELIM lines).
func.func @move_dependent_arith_op(
    %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
    %arg1: index, %f: f32) -> tensor<10xf32>
{
  %0 = tensor.empty() : tensor<5xf32>
  %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
  // Pure producers of the offset, defined after the fill; movable.
  %c5 = arith.constant 5 : index
  %offset = arith.addi %arg1, %c5 : index
  %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
      : tensor<5xf32> into tensor<10xf32>
  return %2 : tensor<10xf32>
}

// -----

// Test that side-effecting operations are not moved, preventing empty
// tensor elimination.

// CHECK-LABEL: func.func @side_effecting_op_blocks_movement(
// CHECK: memref.alloc
// CHECK: linalg.fill
// CHECK: memref.load
// CHECK: memref.subview
// CHECK: memref.copy
// CHECK-ELIM-LABEL: func.func @side_effecting_op_blocks_movement(
// CHECK-ELIM: tensor.empty
// CHECK-ELIM: linalg.fill
// CHECK-ELIM: memref.load
// CHECK-ELIM: tensor.insert_slice
// `%offset` is produced by memref.load, which is not pure; the pass therefore
// refuses to move it, and the tensor.empty survives (the CHECK lines above
// expect the alloc/fill/copy sequence to remain).
func.func @side_effecting_op_blocks_movement(
    %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
    %mem: memref<index>, %f: f32) -> tensor<10xf32>
{
  %0 = tensor.empty() : tensor<5xf32>
  %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
  // Side-effecting producer of the insert_slice offset; blocks movement.
  %offset = memref.load %mem[] : memref<index>
  %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
      : tensor<5xf32> into tensor<10xf32>
  return %2 : tensor<10xf32>
}
Loading
Loading