-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] Make generic skip packing init operand when not used in DataLayoutPropagation #146139
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir-linalg Author: Zhuoran Yin (jerryyin) ChangesIn both Full diff: https://github.com/llvm/llvm-project/pull/146139.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7188987e5e938..3b8ed6bfb6e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -358,6 +358,19 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
return newGenericOp;
}
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+ Block *block = genericOp.getBody();
+ int numBlockArgs = block->getNumArguments();
+ int numDpsOuts = genericOp.getNumDpsInits();
+ int initArgStartIndex = numBlockArgs - numDpsOuts;
+ for (int i = 0; i < numDpsOuts; ++i) {
+ int matchingInitArgIndex = initArgStartIndex + i;
+ if (!block->getArgument(matchingInitArgIndex).use_empty())
+ return false;
+ }
+ return true;
+}
+
/// Bubbles up linalg.pack op through a producer generic op. This
/// swap pack(generic) to generic(pack). The new generic op works on packed
/// domain; pack ops are created for input and output operands. E.g.,
@@ -470,12 +483,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, opOperand);
- // If the dps init operand of the generic is a tensor.empty forward the pack
- // op destination.
+ // Forward the new tensor.empty as a destination if it is one of the following
+ // situations:
+ // 1) The dps init operand is a tensor.empty.
+ // 2) The dps init is a write-only operand, i.e., it is not used in the
+ // genericOp
Value dest = packedOutOperand;
- if (auto initTensor = genericOp.getDpsInitOperand(0)
- ->get()
- .getDefiningOp<tensor::EmptyOp>()) {
+ auto initTensor =
+ genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+ if (initTensor || isGenericOutsNotUsed(genericOp)) {
dest = packOpDest;
}
// pack(unpack) isn't naively foldable because the unpack op can be from
@@ -1101,12 +1117,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
genericOp, genericOp.getDpsInitOperand(0));
auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
- // If the dps init operand of the generic is a tensor.empty, do not pack it
- // and forward the new tensor.empty as a destination.
+ // Forward the new tensor.empty as a destination if it is one of the following
+ // situations:
+ // 1) The dps init operand is a tensor.empty.
+ // 2) The dps init is a write-only operand, i.e., it is not used in the
+ // genericOp
Value dest = packedOutOperand;
- if (auto initTensor = genericOp.getDpsInitOperand(0)
- ->get()
- .getDefiningOp<tensor::EmptyOp>()) {
+ auto initTensor =
+ genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+ if (initTensor || isGenericOutsNotUsed(genericOp)) {
if (destPack)
dest = destPack.getDest();
}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 31c9e9ed3c501..6fc8d9f152f4e 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -435,6 +435,40 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
// -----
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims_unused_init(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
+ %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<128x256xi32>)
+ outs(%init : tensor<128x256xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32):
+ %4 = arith.addi %arg3, %arg3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<128x256xi32>
+ %empty = tensor.empty() : tensor<16x4x32x16xi32>
+ %pack = linalg.pack %elem
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [0, 1]
+ inner_tiles = [32, 16]
+ into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+ return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[ARG1_EMPTY]]
+
+// -----
+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
@@ -497,7 +531,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
+func.func @unpack_element_type_change_no_use(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
%0 = tensor.empty() : tensor<12x56x56x64xf32>
%1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
@@ -509,17 +543,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
}
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @unpack_element_type_change
+// CHECK-LABEL: func.func @unpack_element_type_change_no_use
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -1402,13 +1433,10 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
-// CHECK: %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
-// CHECK-SAME: into %[[ARG2_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME: outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
// CHECK-SAME: into %[[ARG2]]
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
|
@llvm/pr-subscribers-mlir Author: Zhuoran Yin (jerryyin) ChangesIn both Full diff: https://github.com/llvm/llvm-project/pull/146139.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7188987e5e938..3b8ed6bfb6e6f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -358,6 +358,19 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
return newGenericOp;
}
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+ Block *block = genericOp.getBody();
+ int numBlockArgs = block->getNumArguments();
+ int numDpsOuts = genericOp.getNumDpsInits();
+ int initArgStartIndex = numBlockArgs - numDpsOuts;
+ for (int i = 0; i < numDpsOuts; ++i) {
+ int matchingInitArgIndex = initArgStartIndex + i;
+ if (!block->getArgument(matchingInitArgIndex).use_empty())
+ return false;
+ }
+ return true;
+}
+
/// Bubbles up linalg.pack op through a producer generic op. This
/// swap pack(generic) to generic(pack). The new generic op works on packed
/// domain; pack ops are created for input and output operands. E.g.,
@@ -470,12 +483,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, opOperand);
- // If the dps init operand of the generic is a tensor.empty forward the pack
- // op destination.
+ // Forward the new tensor.empty as a destination if it is one of the following
+ // situations:
+ // 1) The dps init operand is a tensor.empty.
+ // 2) The dps init is a write-only operand, i.e., it is not used in the
+ // genericOp
Value dest = packedOutOperand;
- if (auto initTensor = genericOp.getDpsInitOperand(0)
- ->get()
- .getDefiningOp<tensor::EmptyOp>()) {
+ auto initTensor =
+ genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+ if (initTensor || isGenericOutsNotUsed(genericOp)) {
dest = packOpDest;
}
// pack(unpack) isn't naively foldable because the unpack op can be from
@@ -1101,12 +1117,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
genericOp, genericOp.getDpsInitOperand(0));
auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
- // If the dps init operand of the generic is a tensor.empty, do not pack it
- // and forward the new tensor.empty as a destination.
+ // Forward the new tensor.empty as a destination if it is one of the following
+ // situations:
+ // 1) The dps init operand is a tensor.empty.
+ // 2) The dps init is a write-only operand, i.e., it is not used in the
+ // genericOp
Value dest = packedOutOperand;
- if (auto initTensor = genericOp.getDpsInitOperand(0)
- ->get()
- .getDefiningOp<tensor::EmptyOp>()) {
+ auto initTensor =
+ genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+ if (initTensor || isGenericOutsNotUsed(genericOp)) {
if (destPack)
dest = destPack.getDest();
}
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 31c9e9ed3c501..6fc8d9f152f4e 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -435,6 +435,40 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
// -----
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims_unused_init(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
+ %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+ ins(%arg0 : tensor<128x256xi32>)
+ outs(%init : tensor<128x256xi32>) {
+ ^bb0(%arg3: i32, %arg4: i32):
+ %4 = arith.addi %arg3, %arg3 : i32
+ linalg.yield %4 : i32
+ } -> tensor<128x256xi32>
+ %empty = tensor.empty() : tensor<16x4x32x16xi32>
+ %pack = linalg.pack %elem
+ outer_dims_perm = [1, 0]
+ inner_dims_pos = [0, 1]
+ inner_tiles = [32, 16]
+ into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+ return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[ARG1_EMPTY]]
+
+// -----
+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
@@ -497,7 +531,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
-func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
+func.func @unpack_element_type_change_no_use(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
%0 = tensor.empty() : tensor<12x56x56x64xf32>
%1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
@@ -509,17 +543,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
}
// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @unpack_element_type_change
+// CHECK-LABEL: func.func @unpack_element_type_change_no_use
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -1402,13 +1433,10 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
-// CHECK: %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
-// CHECK-SAME: into %[[ARG2_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME: outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
// CHECK-SAME: into %[[ARG2]]
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
|
In both
bubbleUpPackOpThroughGenericOp()
orpushDownUnPackOpThroughGenericOp()
, we can simplify the lowered IR by removing the pack of an empty when the init tensor isn't used in generic op. Instead of packing an empty tensor, the empty tensor can be forwarded to the generic output. This allows cleaner result after data layout propagation.