diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 7188987e5e938..31ac87bacf267 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -358,6 +358,12 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
   return newGenericOp;
 }
 
+static bool isGenericOutsNotUsed(linalg::GenericOp genericOp) {
+  return llvm::all_of(genericOp.getDpsInitsMutable(), [&](OpOperand &operand) {
+    return genericOp.getMatchingBlockArgument(&operand).use_empty();
+  });
+}
+
 /// Bubbles up linalg.pack op through a producer generic op. This
 /// swap pack(generic) to generic(pack). The new generic op works on packed
 /// domain; pack ops are created for input and output operands. E.g.,
@@ -470,12 +476,15 @@ bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
       getOrCreatePackedViewOfOperand(rewriter, genericOp.getLoc(), *packInfo,
                                      genericOp, opOperand);
 
-  // If the dps init operand of the generic is a tensor.empty forward the pack
-  // op destination.
+  // Forward the pack op destination in either of the following
+  // situations:
+  // 1) The dps init operand is a tensor.empty.
+  // 2) The dps init is a write-only operand, i.e., it is not used in the
+  //    genericOp body.
   Value dest = packedOutOperand;
-  if (auto initTensor = genericOp.getDpsInitOperand(0)
-                            ->get()
-                            .getDefiningOp<tensor::EmptyOp>()) {
+  auto initTensor =
+      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+  if (initTensor || isGenericOutsNotUsed(genericOp)) {
     dest = packOpDest;
   }
   // pack(unpack) isn't naively foldable because the unpack op can be from
@@ -1101,12 +1110,15 @@ pushDownUnPackOpThroughGenericOp(RewriterBase &rewriter, GenericOp genericOp,
                                      genericOp, genericOp.getDpsInitOperand(0));
   auto destPack = packedOutOperand.getDefiningOp<linalg::PackOp>();
 
-  // If the dps init operand of the generic is a tensor.empty, do not pack it
-  // and forward the new tensor.empty as a destination.
+  // Forward the new tensor.empty as a destination in either of the
+  // following situations:
+  // 1) The dps init operand is a tensor.empty.
+  // 2) The dps init is a write-only operand, i.e., it is not used in the
+  //    genericOp body.
   Value dest = packedOutOperand;
-  if (auto initTensor = genericOp.getDpsInitOperand(0)
-                            ->get()
-                            .getDefiningOp<tensor::EmptyOp>()) {
+  auto initTensor =
+      genericOp.getDpsInitOperand(0)->get().getDefiningOp<tensor::EmptyOp>();
+  if (initTensor || isGenericOutsNotUsed(genericOp)) {
     if (destPack)
       dest = destPack.getDest();
   }
diff --git a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
index 31c9e9ed3c501..6fc8d9f152f4e 100644
--- a/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
+++ b/mlir/test/Dialect/Linalg/data-layout-propagation.mlir
@@ -435,6 +435,40 @@ func.func @elem_pack_transpose_outer_dims(%arg0: tensor<128x256xi32>, %init: ten
 
 // -----
 
+#map0 = affine_map<(d0, d1) -> (d0, d1)>
+func.func @elem_pack_transpose_outer_dims_unused_init(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32>{
+  %elem = linalg.generic {indexing_maps = [#map0, #map0], iterator_types = ["parallel", "parallel"]}
+      ins(%arg0 : tensor<128x256xi32>)
+      outs(%init : tensor<128x256xi32>) {
+    ^bb0(%arg3: i32, %arg4: i32):
+      %4 = arith.addi %arg3, %arg3 : i32
+      linalg.yield %4 : i32
+  } -> tensor<128x256xi32>
+  %empty = tensor.empty() : tensor<16x4x32x16xi32>
+  %pack = linalg.pack %elem
+    outer_dims_perm = [1, 0]
+    inner_dims_pos = [0, 1]
+    inner_tiles = [32, 16]
+    into %empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
+  return %pack : tensor<16x4x32x16xi32>
+}
+
+// CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
+// CHECK-LABEL: func.func @elem_pack_transpose_outer_dims_unused_init
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
+// CHECK: %[[ARG1_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[ARG0_EMPTY:.+]] = tensor.empty() : tensor<16x4x32x16xi32>
+// CHECK: %[[PACKED_ARG0:.+]] = linalg.pack %[[ARG0]]
+// CHECK-SAME: outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
+// CHECK-SAME: into %[[ARG0_EMPTY]]
+// CHECK: %[[RES:.+]] = linalg.generic
+// CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
+// CHECK-SAME: ins(%[[PACKED_ARG0]]
+// CHECK-SAME: outs(%[[ARG1_EMPTY]]
+
+// -----
+
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
 func.func @unpack_on_output(%arg0: tensor<12x2x56x56x32xf32>) -> tensor<12x56x56x64xf32> {
@@ -497,7 +531,7 @@ func.func @unpack_on_input(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56
 
 #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
 
-func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
+func.func @unpack_element_type_change_no_use(%arg0: tensor<12x2x56x56x32xf32>, %init: tensor<12x56x56x64xf16>) -> tensor<12x56x56x64xf16> {
   %0 = tensor.empty() : tensor<12x56x56x64xf32>
   %1 = linalg.unpack %arg0 outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32] into %0 : tensor<12x2x56x56x32xf32> -> tensor<12x56x56x64xf32>
   %2 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%1: tensor<12x56x56x64xf32>) outs(%init : tensor<12x56x56x64xf16>) {
@@ -509,17 +543,14 @@ func.func @unpack_element_type_change(%arg0: tensor<12x2x56x56x32xf32>, %init: t
 }
 
 // CHECK: #[[$MAP:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
-// CHECK-LABEL: func.func @unpack_element_type_change
+// CHECK-LABEL: func.func @unpack_element_type_change_no_use
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG1_PACK_EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
-// CHECK: %[[ARG1_PACK:.+]] = linalg.pack %[[ARG1]]
-// CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
-// CHECK-SAME: into %[[ARG1_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<12x2x56x56x32xf16>
 // CHECK: %[[RES:.+]] = linalg.generic
 // CHECK-SAME: indexing_maps = [#[[$MAP]], #[[$MAP]]]
 // CHECK-SAME: ins(%[[ARG0]]
-// CHECK-SAME: outs(%[[ARG1_PACK]]
+// CHECK-SAME: outs(%[[EMPTY]]
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[RES]]
 // CHECK-SAME: outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3] inner_tiles = [32]
 // CHECK-SAME: into %[[ARG1]]
@@ -1402,13 +1433,10 @@ func.func @push_unpack_in_padded_domain_foldable(%arg0: tensor<8x8x4x8xf32>, %de
 // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]
 // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]
-// CHECK: %[[ARG2_PACK_EMPTY:.+]] = tensor.empty
-// CHECK: %[[ARG2_PACK:.+]] = linalg.pack %[[ARG2]]
-// CHECK-SAME: inner_dims_pos = [0, 1] inner_tiles = [4, 8]
-// CHECK-SAME: into %[[ARG2_PACK_EMPTY]]
+// CHECK: %[[EMPTY:.+]] = tensor.empty
 // CHECK: %[[GENERIC:.+]] = linalg.generic
 // CHECK-SAME: ins(%[[ARG0]] : tensor<8x8x4x8xf32>)
-// CHECK-SAME: outs(%[[ARG2_PACK]] : tensor)
+// CHECK-SAME: outs(%[[EMPTY]] : tensor)
 // CHECK: %[[UNPACK:.+]] = linalg.unpack %[[GENERIC]]
 // CHECK-SAME: into %[[ARG2]]
 // CHECK: return %[[UNPACK]] : tensor
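
Illustration (editor's note, not part of the patch): the effect of the new write-only check is easiest to see in the @elem_pack_transpose_outer_dims_unused_init test above. Because the generic's out block argument (%arg4) is unused, the bubbled-up pack no longer materializes a packed copy of %init; a fresh tensor.empty is forwarded as the outs operand instead. Sketched from that test's CHECK lines, with illustrative SSA names (%init_empty, %packed_arg0, etc. are not taken from the patch), the propagated IR looks roughly like:

  #map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  func.func @example(%arg0: tensor<128x256xi32>, %init: tensor<128x256xi32>) -> tensor<16x4x32x16xi32> {
    // %init is write-only in the generic below, so no linalg.pack of %init is
    // created; a fresh tensor.empty is used as the destination directly.
    %init_empty = tensor.empty() : tensor<16x4x32x16xi32>
    %arg0_empty = tensor.empty() : tensor<16x4x32x16xi32>
    // Only the read operand %arg0 is packed: 128x256 with tiles [32, 16] and
    // outer_dims_perm [1, 0] yields the 16x4x32x16 packed layout.
    %packed_arg0 = linalg.pack %arg0
        outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [32, 16]
        into %arg0_empty : tensor<128x256xi32> -> tensor<16x4x32x16xi32>
    %res = linalg.generic {indexing_maps = [#map, #map],
                           iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
        ins(%packed_arg0 : tensor<16x4x32x16xi32>)
        outs(%init_empty : tensor<16x4x32x16xi32>) {
    ^bb0(%in: i32, %out: i32):
      %sum = arith.addi %in, %in : i32
      linalg.yield %sum : i32
    } -> tensor<16x4x32x16xi32>
    return %res : tensor<16x4x32x16xi32>
  }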