Conversation

@qedawkins
Contributor

Currently, empty tensor elimination, which constructs a SubsetExtractionOp to match a SubsetInsertionOp at the end of a DPS chain, fails if any operand required by the insertion op does not dominate the insertion point chosen for the extraction op.

This change improves the transformation by attempting to move all pure producers of the required operands up to the insertion point of the extraction op. In the process, it strengthens a number of existing tests for empty tensor elimination.
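
To make the failure mode concrete, here is a minimal MLIR sketch (function and value names hypothetical, adapted from the new @move_dependent_arith_op test in the diff below). Before this change, the empty tensor could not be eliminated because %offset, which the insertion op needs, is defined after the tensor.empty that the new extraction would replace:

  // Before: no insertion point for a tensor.extract_slice replacing %empty
  // dominates both %dest and %offset.
  func.func @before(%dest: tensor<10xf32>, %idx: index, %f: f32) -> tensor<10xf32> {
    %empty = tensor.empty() : tensor<5xf32>
    %fill = linalg.fill ins(%f : f32) outs(%empty : tensor<5xf32>) -> tensor<5xf32>
    %c5 = arith.constant 5 : index
    %offset = arith.addi %idx, %c5 : index
    %res = tensor.insert_slice %fill into %dest[%offset] [5] [1]
        : tensor<5xf32> into tensor<10xf32>
    return %res : tensor<10xf32>
  }

  // After: the pure producers of %offset are moved up, so the tensor.empty can
  // be replaced by a subset extraction of the destination.
  func.func @after(%dest: tensor<10xf32>, %idx: index, %f: f32) -> tensor<10xf32> {
    %c5 = arith.constant 5 : index
    %offset = arith.addi %idx, %c5 : index
    %slice = tensor.extract_slice %dest[%offset] [5] [1]
        : tensor<10xf32> to tensor<5xf32>
    %fill = linalg.fill ins(%f : f32) outs(%slice : tensor<5xf32>) -> tensor<5xf32>
    %res = tensor.insert_slice %fill into %dest[%offset] [5] [1]
        : tensor<5xf32> into tensor<10xf32>
    return %res : tensor<10xf32>
  }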

@llvmbot added the mlir:core (MLIR Core Infrastructure), mlir, and mlir:bufferization (Bufferization infrastructure) labels on Nov 26, 2025
@llvmbot
Member

llvmbot commented Nov 26, 2025

@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Quinn Dawkins (qedawkins)

Changes

Currently, empty tensor elimination, which constructs a SubsetExtractionOp to match a SubsetInsertionOp at the end of a DPS chain, fails if any operand required by the insertion op does not dominate the insertion point chosen for the extraction op.

This change improves the transformation by attempting to move all pure producers of the required operands up to the insertion point of the extraction op. In the process, it strengthens a number of existing tests for empty tensor elimination.


Full diff: https://github.com/llvm/llvm-project/pull/169718.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Transforms/RegionUtils.h (+6-2)
  • (modified) mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp (+10-2)
  • (modified) mlir/lib/Transforms/Utils/RegionUtils.cpp (+17-4)
  • (modified) mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir (+80-19)
diff --git a/mlir/include/mlir/Transforms/RegionUtils.h b/mlir/include/mlir/Transforms/RegionUtils.h
index 2ed96afbace81..6a0c94b06c6b2 100644
--- a/mlir/include/mlir/Transforms/RegionUtils.h
+++ b/mlir/include/mlir/Transforms/RegionUtils.h
@@ -85,11 +85,15 @@ LogicalResult moveOperationDependencies(RewriterBase &rewriter, Operation *op,
 /// only for movement of definitions within the same basic block. Note that this
 /// is an all-or-nothing approach. Either definitions of all values are moved
 /// before insertion point, or none of them are.
+/// If `ignoreSideEffect` is set, this will allow movement of all dependent
+/// producers regardless of whether they are side effecting.
 LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
                                    Operation *insertionPoint,
-                                   DominanceInfo &dominance);
+                                   DominanceInfo &dominance,
+                                   bool ignoreSideEffects = true);
 LogicalResult moveValueDefinitions(RewriterBase &rewriter, ValueRange values,
-                                   Operation *insertionPoint);
+                                   Operation *insertionPoint,
+                                   bool ignoreSideEffects = true);
 
 /// Run a set of structural simplifications over the given regions. This
 /// includes transformations like unreachable block elimination, dead argument
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index 1784964cf9b95..0843b4398b24f 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -16,6 +16,7 @@
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/Interfaces/SubsetOpInterface.h"
+#include "mlir/Transforms/RegionUtils.h"
 
 namespace mlir {
 namespace bufferization {
@@ -105,8 +106,15 @@ Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
   // this replacement.
   Operation *insertionPoint =
       findValidInsertionPoint(emptyTensorOp, user, neededValues);
-  if (!insertionPoint)
-    return {};
+  if (!insertionPoint) {
+    // If no already suitable insertion point was found, attempt to move all
+    // needed values before the user.
+    if (failed(moveValueDefinitions(rewriter, neededValues, user,
+                                    /*ignoreSideEffects=*/false))) {
+      return {};
+    }
+    insertionPoint = user;
+  }
 
   rewriter.setInsertionPoint(insertionPoint);
   Value replacement =
diff --git a/mlir/lib/Transforms/Utils/RegionUtils.cpp b/mlir/lib/Transforms/Utils/RegionUtils.cpp
index 31ae1d1895b81..390fc76cc6533 100644
--- a/mlir/lib/Transforms/Utils/RegionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/RegionUtils.cpp
@@ -1145,7 +1145,8 @@ LogicalResult mlir::moveOperationDependencies(RewriterBase &rewriter,
 LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
                                          ValueRange values,
                                          Operation *insertionPoint,
-                                         DominanceInfo &dominance) {
+                                         DominanceInfo &dominance,
+                                         bool ignoreSideEffects) {
   // Remove the values that already dominate the insertion point.
   SmallVector<Value> prunedValues;
   for (auto value : values) {
@@ -1178,8 +1179,14 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
   // Since current support is to only move within a same basic block,
   // the slices dont need to look past block arguments.
   options.omitBlockArguments = true;
+  bool dependsOnSideEffectingOp = false;
   options.filter = [&](Operation *sliceBoundaryOp) {
-    return !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+    bool mustMove =
+        !dominance.properlyDominates(sliceBoundaryOp, insertionPoint);
+    if (mustMove && !isPure(sliceBoundaryOp)) {
+      dependsOnSideEffectingOp = true;
+    }
+    return mustMove;
   };
   llvm::SetVector<Operation *> slice;
   for (auto value : prunedValues) {
@@ -1188,6 +1195,10 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
     (void)result;
   }
 
+  // Check if any operation in the slice is side-effecting.
+  if (!ignoreSideEffects && dependsOnSideEffectingOp)
+    return failure();
+
   // If the slice contains `insertionPoint` cannot move the dependencies.
   if (slice.contains(insertionPoint)) {
     return rewriter.notifyMatchFailure(
@@ -1206,7 +1217,9 @@ LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
 
 LogicalResult mlir::moveValueDefinitions(RewriterBase &rewriter,
                                          ValueRange values,
-                                         Operation *insertionPoint) {
+                                         Operation *insertionPoint,
+                                         bool ignoreSideEffects) {
   DominanceInfo dominance(insertionPoint);
-  return moveValueDefinitions(rewriter, values, insertionPoint, dominance);
+  return moveValueDefinitions(rewriter, values, insertionPoint, dominance,
+                              ignoreSideEffects);
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
index 8249d59b2374e..3929f5be3b4ef 100644
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -368,21 +368,18 @@ func.func @multiple_materialize_in_destination_buffer(%m: memref<5xf32>, %f: f32
 
 // -----
 
-// `EmptyTensorElimination` fails to find a valid insertion
-// point for the new injected `SubsetExtraction`.
-// CHECK-LABEL:   func.func @fail_to_eliminate_any_empty_tensors
-func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
+// CHECK-LABEL:   func.func @eliminate_all_empty_tensors
+func.func @eliminate_all_empty_tensors() -> tensor<5x6x128xf32> {
   %cst_1 = arith.constant 1.0 : f32
   %cst_2 = arith.constant 2.0 : f32
-  // CHECK: memref.alloc
-  // CHECK: memref.alloc
-  // CHECK: memref.alloc
+  // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+  // CHECK-NOT: memref.alloc
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
   %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %empty_2 = tensor.empty() : tensor<5x6x64xf32>
   %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
-  // CHECK: memref.copy
+  // CHECK-NOT: memref.copy
   %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
@@ -392,20 +389,19 @@ func.func @fail_to_eliminate_any_empty_tensors() -> tensor<5x6x128xf32> {
 
 // -----
 
-// CHECK-LABEL:   func.func @succeed_to_eliminate_one_empty_tensor
-func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
+// CHECK-LABEL:   func.func @eliminate_concatenated_empty_tensors
+func.func @eliminate_concatenated_empty_tensors() -> tensor<5x6x128xf32> {
   %cst_1 = arith.constant 1.0 : f32
   %cst_2 = arith.constant 2.0 : f32
   // CHECK: memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
-  // CHECK: memref.alloc
   // CHECK-NOT: memref.alloc
-  %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
+  %concatenated_empty = tensor.empty() : tensor<5x6x128xf32>
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
   %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %empty_2 = tensor.empty() : tensor<5x6x64xf32>
   %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_2 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
-  // CHECK: memref.copy
-  %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
+  // CHECK-NOT: memref.copy
+  %inserted_slice_1 = tensor.insert_slice %res_1 into %concatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
@@ -420,20 +416,22 @@ func.func @succeed_to_eliminate_one_empty_tensor() -> tensor<5x6x128xf32> {
 
 // CHECK-ELIM-LABEL:   func.func @multi_use_of_the_same_tensor_empty
 // CHECK-LABEL:   func.func @multi_use_of_the_same_tensor_empty
+//       CHECK:   memref.alloc() {alignment = 64 : i64} : memref<5x6x128xf32>
+//   CHECK-NOT:   memref.alloc
+//   CHECK-NOT:   memref.copy
+//  CHECK-ELIM:   tensor.extract_slice {{.*}}[0, 0, 0]
+//  CHECK-ELIM:   linalg.fill
+//  CHECK-ELIM:   tensor.extract_slice {{.*}}[0, 0, 64]
+//  CHECK-ELIM:   linalg.fill
 func.func @multi_use_of_the_same_tensor_empty() -> tensor<5x6x128xf32> {
   %cst_1 = arith.constant 1.0 : f32
   %cst_2 = arith.constant 2.0 : f32
   %cancatenated_empty = tensor.empty() : tensor<5x6x128xf32>
   %empty_1 = tensor.empty() : tensor<5x6x64xf32>
-  // CHECK-ELIM: %[[VAL_3:.*]] = tensor.extract_slice
-  // CHECK-ELIM: linalg.fill ins(%[[VAL_0:.*]] : f32) outs(%[[VAL_3]]
-  // CHECK-ELIM-NOT: linalg.fill ins(%[[VAL_1:.*]] : f32) outs(%[[VAL_3]]
   %res_1 = linalg.fill ins(%cst_1 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
   %res_2 = linalg.fill ins(%cst_2 : f32) outs(%empty_1 : tensor<5x6x64xf32>) -> tensor<5x6x64xf32>
-  // CHECK: memref.copy
   %inserted_slice_1 = tensor.insert_slice %res_1 into %cancatenated_empty[0, 0, 0][5, 6, 64][1, 1, 1]
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
-  // CHECK-NOT: memref.copy
   %inserted_slice_2 = tensor.insert_slice %res_2 into %inserted_slice_1[0, 0, 64][5, 6, 64][1, 1, 1]
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   return %inserted_slice_2 : tensor<5x6x128xf32>
@@ -476,3 +474,66 @@ func.func @direct_use_of_tensor_empty(%arg0: tensor<5x6x128xf32>) -> tensor<5x6x
       : tensor<5x6x64xf32> into tensor<5x6x128xf32>
   return %inserted_slice_1 : tensor<5x6x128xf32>
 }
+
+// -----
+
+// Test that dependent pure operations are moved before the
+// insertion point to enable empty tensor elimination.
+
+//      CHECK-LABEL: func.func @move_dependent_arith_op(
+//       CHECK-SAME:     %[[ARG0:.*]]: memref<10xf32>
+//       CHECK-SAME:     %[[ARG1:.*]]: index
+//        CHECK-NOT:   memref.alloc
+//            CHECK:   %[[C5:.*]] = arith.constant 5 : index
+//            CHECK:   %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
+//            CHECK:   %[[SV:.*]] = memref.subview %[[ARG0]][%[[OFFSET]]] [5] [1]
+//            CHECK:   linalg.fill {{.*}} outs(%[[SV]]
+//            CHECK:   return %[[ARG0]]
+// CHECK-ELIM-LABEL: func.func @move_dependent_arith_op(
+//  CHECK-ELIM-SAME:     %[[ARG0:.*]]: tensor<10xf32>
+//  CHECK-ELIM-SAME:     %[[ARG1:.*]]: index
+//       CHECK-ELIM:   %[[C5:.*]] = arith.constant 5 : index
+//       CHECK-ELIM:   %[[OFFSET:.*]] = arith.addi %[[ARG1]], %[[C5]]
+//       CHECK-ELIM:   %[[SLICE:.*]] = tensor.extract_slice %[[ARG0]][%[[OFFSET]]] [5] [1]
+//       CHECK-ELIM:   %[[FILL:.*]] = linalg.fill {{.*}} outs(%[[SLICE]]
+//       CHECK-ELIM:   tensor.insert_slice %[[FILL]] into %[[ARG0]][%[[OFFSET]]]
+func.func @move_dependent_arith_op(
+    %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+    %arg1: index, %f: f32) -> tensor<10xf32>
+{
+  %0 = tensor.empty() : tensor<5xf32>
+  %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  %c5 = arith.constant 5 : index
+  %offset = arith.addi %arg1, %c5 : index
+  %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
+      : tensor<5xf32> into tensor<10xf32>
+  return %2 : tensor<10xf32>
+}
+
+// -----
+
+// Test that side-effecting operations are not moved, preventing empty
+// tensor elimination.
+
+//      CHECK-LABEL: func.func @side_effecting_op_blocks_movement(
+//            CHECK:   memref.alloc
+//            CHECK:   linalg.fill
+//            CHECK:   memref.load
+//            CHECK:   memref.subview
+//            CHECK:   memref.copy
+// CHECK-ELIM-LABEL: func.func @side_effecting_op_blocks_movement(
+//       CHECK-ELIM:   tensor.empty
+//       CHECK-ELIM:   linalg.fill
+//       CHECK-ELIM:   memref.load
+//       CHECK-ELIM:   tensor.insert_slice
+func.func @side_effecting_op_blocks_movement(
+    %arg0: tensor<10xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+    %mem: memref<index>, %f: f32) -> tensor<10xf32>
+{
+  %0 = tensor.empty() : tensor<5xf32>
+  %1 = linalg.fill ins(%f : f32) outs(%0 : tensor<5xf32>) -> tensor<5xf32>
+  %offset = memref.load %mem[] : memref<index>
+  %2 = tensor.insert_slice %1 into %arg0[%offset][5][1]
+      : tensor<5xf32> into tensor<10xf32>
+  return %2 : tensor<10xf32>
+}
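
As a usage sketch of the updated utility (assuming only the RegionUtils.h signatures shown above; the wrapper function and its arguments are hypothetical), a caller that must not reorder side-effecting producers opts in to the stricter behavior like this:

  // Hoist the producers of `neededValues` above `user`, failing instead of
  // moving anything in the backward slice that is not pure.
  #include "mlir/Transforms/RegionUtils.h"

  mlir::LogicalResult hoistNeededValues(mlir::RewriterBase &rewriter,
                                        mlir::ValueRange neededValues,
                                        mlir::Operation *user) {
    return mlir::moveValueDefinitions(rewriter, neededValues, user,
                                      /*ignoreSideEffects=*/false);
  }

This mirrors the new fallback in buildSubsetExtraction above; with the defaulted ignoreSideEffects = true, existing callers keep their previous behavior.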

@MaheshRavishankar
Contributor

Did you have a use case for ignoreSideEffects = true? It is probably better to have ignoreSideEffects = false by default (or just not change the API and make sure that we always account for side-effecting operations).

@qedawkins
Contributor Author

Did you have a use case for ignoreSideEffects = true? It is probably better to have ignoreSideEffects = false by default (or just not change the API and make sure that we always account for side-effecting operations).

I did not; it was just an attempt to keep the API NFC. I can happily drop the flag and default it to true, then.
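
For reference, the defaulted parameter is what keeps the signature change NFC for existing callers; a minimal sketch (call sites hypothetical):

  // Existing call sites compile and behave as before:
  (void)mlir::moveValueDefinitions(rewriter, values, insertionPoint);
  // New callers opt in to the side-effect check explicitly:
  (void)mlir::moveValueDefinitions(rewriter, values, insertionPoint,
                                   /*ignoreSideEffects=*/false);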

qedawkins and others added 2 commits on December 5, 2025 at 10:52
…ty-tensors

Currently, empty tensor elimination, which constructs a
SubsetExtractionOp to match a SubsetInsertionOp at the end of a DPS
chain, fails if any operand required by the insertion op does not
dominate the insertion point chosen for the extraction op.

This change improves the transformation by attempting to move all pure
producers of the required operands up to the insertion point of the
extraction op. In the process, it strengthens a number of existing
tests for empty tensor elimination.
@qedawkins force-pushed the eliminate-empty-tensors branch from 6f69946 to 894db9b on December 5, 2025 at 15:53
@qedawkins merged commit bb17dfa into llvm:main on December 5, 2025
10 checks passed
@qedawkins deleted the eliminate-empty-tensors branch on December 5, 2025 at 19:40