[MLIR] Make generic skip packing init operand when not used in DataLayoutPropagation #146139


Open · wants to merge 1 commit into main
39 changes: 29 additions & 10 deletions mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -358,6 +358,19 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
return newGenericOp;
}

+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  Block *block = genericOp.getBody();
+  int numBlockArgs = block->getNumArguments();
+  int numDpsOuts = genericOp.getNumDpsInits();
+  int initArgStartIndex = numBlockArgs - numDpsOuts;
+  for (int i = 0; i < numDpsOuts; ++i) {
+    int matchingInitArgIndex = initArgStartIndex + i;
+    if (!block->getArgument(matchingInitArgIndex).use_empty())
+      return false;
+  }
+  return true;
+}
Comment on lines +362 to +372
Contributor:
nit: This can be simplified with getDpsInitsMutable and getMatchingBlockArgument.

return llvm::all_of(genericOp.getDpsInitsMutable(),
                    [&](OpOperand &operand) {
                      return genericOp.getMatchingBlockArgument(&operand).use_empty();
                    });
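
For intuition about what the new isGenericOutsNotUsed helper detects, here is a minimal illustrative pair of generics (hypothetical IR written for this explanation, not taken from the patch or its tests; @outs_usage_example, %in, and %init are made-up names). In the first generic the output block argument %out is dead, so the init is write-only; in the second, %out feeds the body, so the init's contents are live and its packing cannot be skipped.

#map = affine_map<(d0, d1) -> (d0, d1)>
func.func @outs_usage_example(%in: tensor<8x8xf32>, %init: tensor<8x8xf32>)
    -> (tensor<8x8xf32>, tensor<8x8xf32>) {
  // Write-only init: %out has no uses, so the helper would return true.
  %r0 = linalg.generic {indexing_maps = [#map, #map],
                        iterator_types = ["parallel", "parallel"]}
      ins(%in : tensor<8x8xf32>) outs(%init : tensor<8x8xf32>) {
  ^bb0(%x: f32, %out: f32):
    %y = arith.addf %x, %x : f32
    linalg.yield %y : f32
  } -> tensor<8x8xf32>
  // Read-write init: %out is consumed, so the helper would return false.
  %r1 = linalg.generic {indexing_maps = [#map, #map],
                        iterator_types = ["parallel", "parallel"]}
      ins(%in : tensor<8x8xf32>) outs(%init : tensor<8x8xf32>) {
  ^bb0(%x: f32, %out: f32):
    %y = arith.addf %x, %out : f32
    linalg.yield %y : f32
  } -> tensor<8x8xf32>
  return %r0, %r1 : tensor<8x8xf32>, tensor<8x8xf32>
}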


/// Bubbles up a linalg.pack op through a producer generic op. This swaps
/// pack(generic) to generic(pack). The new generic op works on the packed
/// domain; pack ops are created for the input and output operands. E.g.,
@@ -470,12 +483,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
genericOp, opOperand);

-  // If the dps init operand of the generic is a tensor.empty forward the pack
-  // op destination.
+  // Forward the new tensor.empty as a destination in either of the following
+  // situations:
+  // 1) The dps init operand is a tensor.empty.
+  // 2) The dps init is a write-only operand, i.e., it is not used in the
+  //    genericOp body, so its contents are never read.
  Value dest = packedOutOperand;
-  if (auto initTensor = genericOp.getDpsInitOperand(0)
-                            ->get()
-                            .getDefiningOp<tensor::EmptyOp>()) {
+  auto initTensor =
+      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+  if (initTensor || isGenericOutsNotUsed(genericOp)) {
    dest = packOpDest;
  }
  // pack(unpack) isn't naively foldable because the unpack op can be from
@@ -1101,12 +1117,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
genericOp, genericOp.getDpsInitOperand(0));
auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();

-  // If the dps init operand of the generic is a tensor.empty, do not pack it
-  // and forward the new tensor.empty as a destination.
+  // Forward the new tensor.empty as a destination in either of the following
+  // situations:
+  // 1) The dps init operand is a tensor.empty.
+  // 2) The dps init is a write-only operand, i.e., it is not used in the
+  //    genericOp body, so its contents are never read.
  Value dest = packedOutOperand;
-  if (auto initTensor = genericOp.getDpsInitOperand(0)
-                            ->get()
-                            .getDefiningOp<tensor::EmptyOp>()) {
+  auto initTensor =
+      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+  if (initTensor || isGenericOutsNotUsed(genericOp)) {
    if (destPack)
      dest = destPack.getDest();
  }
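
Taken together, the two hunks let both propagation patterns treat a write-only init like a tensor.empty. The following before/after pair is a hedged sketch of the bubble-up rewrite (hypothetical functions @bubble_before and @bubble_after with made-up shapes and tile sizes; the actual expected output is pinned down by the tests below):

#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Before: a pack consumes a generic whose init block argument %out is unused.
func.func @bubble_before(%a: tensor<128x256xf32>, %init: tensor<128x256xf32>,
                         %dest: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
  %0 = linalg.generic {indexing_maps = [#map2, #map2],
                       iterator_types = ["parallel", "parallel"]}
      ins(%a : tensor<128x256xf32>) outs(%init : tensor<128x256xf32>) {
  ^bb0(%x: f32, %out: f32):
    %y = arith.mulf %x, %x : f32
    linalg.yield %y : f32
  } -> tensor<128x256xf32>
  %1 = linalg.pack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %dest : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
  return %1 : tensor<4x16x32x16xf32>
}

// After: only the input is packed; the pack destination %dest is forwarded as
// the new init, so no pack of %init is materialized even though %init is not
// defined by a tensor.empty.
func.func @bubble_after(%a: tensor<128x256xf32>, %init: tensor<128x256xf32>,
                        %dest: tensor<4x16x32x16xf32>) -> tensor<4x16x32x16xf32> {
  %empty = tensor.empty() : tensor<4x16x32x16xf32>
  %a_pack = linalg.pack %a inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %empty : tensor<128x256xf32> -> tensor<4x16x32x16xf32>
  %0 = linalg.generic {indexing_maps = [#map4, #map4],
                       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%a_pack : tensor<4x16x32x16xf32>) outs(%dest : tensor<4x16x32x16xf32>) {
  ^bb0(%x: f32, %out: f32):
    %y = arith.mulf %x, %x : f32
    linalg.yield %y : f32
  } -> tensor<4x16x32x16xf32>
  return %0 : tensor<4x16x32x16xf32>
}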
52 changes: 40 additions & 12 deletions mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -435,6 +435,40 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten

// -----

+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims_unused_init(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+  ^bb0(%arg3: i32, %arg4: i32):
+    %4 = arith.addi %arg3, %arg3 : i32
+    linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %empty = tensor.empty() : tensor<16x4x32x16xi32>
+  %pack = linalg.pack %elem
+      outer_dims_perm = [1, 0]
+      inner_dims_pos = [0, 1]
+      inner_tiles = [32, 16]
+      into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims_unused_init
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[ARG1_EMPTY]]
+
+// -----
+
#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
@@ -497,7 +531,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56

#map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

-func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
+func.func @unpack_element_type_change_no_use(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
%0 = tensor.empty() : tensor<12x56x56x64xf32>
%1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
%2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
@@ -509,17 +543,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
}

// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @unpack_element_type_change
+// CHECK-LABEL: func.func @unpack_element_type_change_no_use
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
// CHECK: %[[RES:.+]] = linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
// CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: outs(%[[EMPTY]]
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
// CHECK-SAME: into %[[ARG1]]
@@ -1402,13 +1433,10 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
-// CHECK: %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
-// CHECK-SAME: into %[[ARG2_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
// CHECK: %[[GENERIC:.+]] = linalg.generic
// CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME: outs(%[[ARG2_PACK]] : tensor<?x8x4x8xbf16>)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x8x4x8xbf16>)
// CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
// CHECK-SAME: into %[[ARG2]]
// CHECK: return %[[UNPACK]] : tensor<?x64xbf16>
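
The push-down direction is symmetric. As a final hedged sketch (hypothetical @push_before and @push_after, with made-up shapes): because the generic never reads %out, it can run directly in the packed domain with a fresh tensor.empty destination, and the unpack sinks below it, writing into the original init.

#map2 = affine_map<(d0, d1) -> (d0, d1)>
#map4 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

// Before: an unpack feeds a generic whose init block argument %out is unused.
func.func @push_before(%a: tensor<4x16x32x16xf32>, %init: tensor<128x256xf32>)
    -> tensor<128x256xf32> {
  %empty = tensor.empty() : tensor<128x256xf32>
  %u = linalg.unpack %a inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %empty : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
  %0 = linalg.generic {indexing_maps = [#map2, #map2],
                       iterator_types = ["parallel", "parallel"]}
      ins(%u : tensor<128x256xf32>) outs(%init : tensor<128x256xf32>) {
  ^bb0(%x: f32, %out: f32):
    %y = arith.negf %x : f32
    linalg.yield %y : f32
  } -> tensor<128x256xf32>
  return %0 : tensor<128x256xf32>
}

// After: the generic works on the packed operand with a fresh tensor.empty as
// its destination (no pack of %init is created), and the unpack moves below
// the generic, unpacking straight into %init.
func.func @push_after(%a: tensor<4x16x32x16xf32>, %init: tensor<128x256xf32>)
    -> tensor<128x256xf32> {
  %empty = tensor.empty() : tensor<4x16x32x16xf32>
  %0 = linalg.generic {indexing_maps = [#map4, #map4],
                       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%a : tensor<4x16x32x16xf32>) outs(%empty : tensor<4x16x32x16xf32>) {
  ^bb0(%x: f32, %out: f32):
    %y = arith.negf %x : f32
    linalg.yield %y : f32
  } -> tensor<4x16x32x16xf32>
  %r = linalg.unpack %0 inner_dims_pos = [0, 1] inner_tiles = [32, 16]
      into %init : tensor<4x16x32x16xf32> -> tensor<128x256xf32>
  return %r : tensor<128x256xf32>
}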