diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index d64f94a49f781..4360055e78691 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1402,6 +1402,7 @@ def PromoteOp : Op:$operands_to_promote, DefaultValuedAttr:$use_full_tile_buffers, UnitAttr:$use_full_tiles_by_default, + UnitAttr:$use_original_subview_size, UnitAttr:$use_alloca, OptionalAttr:$memory_space, OptionalAttr:$mapping, diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h index 2b4855f49695c..8d2b5a240524c 100644 --- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h +++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h @@ -422,6 +422,13 @@ struct LinalgPromotionOptions { useFullTileBuffersDefault = use; return *this; } + /// If true, buffers will be allocated with the original subview size. This + /// may result in more dynamic allocations, in case of dynamic sizes. + bool useOriginalSubviewSize = false; + LinalgPromotionOptions &setUseOriginalSubviewSize(bool originalSize) { + useOriginalSubviewSize = originalSize; + return *this; + } /// Alignment of promoted buffer. If `std::nullopt` do not specify alignment. std::optional alignment; LinalgPromotionOptions &setAlignment(unsigned align) { @@ -796,7 +803,8 @@ FailureOr specializeGenericOp(RewriterBase &rewriter, GenericOp genericOp); /// Create a new buffer using the `allocationFn` provided. The size of this -/// buffer is the smallest constant bounding size along each dimension that +/// buffer is either the original subview size when 'useOriginalSubviewSize' is +/// set to true or the smallest constant bounding size along each dimension that /// can be computed for the size of the result of `subView`. Returns the /// allocated buffer as `fullLocalView` and the view that matches the size of /// the result of subview operation as `partialLocalView`. @@ -806,6 +814,7 @@ struct PromotionInfo { }; FailureOr promoteSubviewAsNewBuffer(OpBuilder &b, Location loc, memref::SubViewOp subView, + bool useOriginalSubviewSize, const AllocBufferCallbackFn &allocationFn, DataLayout &layout); diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 8571d641e26d1..d0031e047b770 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -2408,6 +2408,9 @@ transform::PromoteOp::applyToOne(transform::TransformRewriter &rewriter, if (getUseFullTilesByDefault()) promotionOptions = promotionOptions.setUseFullTileBuffersByDefault( getUseFullTilesByDefault()); + if (getUseOriginalSubviewSize()) + promotionOptions = + promotionOptions.setUseOriginalSubviewSize(getUseOriginalSubviewSize()); if (getUseAlloca()) promotionOptions = promotionOptions.setUseAlloca(getUseAlloca()); if (!getUseFullTileBuffers().empty()) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp index d599ddd220dde..80b37544a99f8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -148,6 +148,9 @@ struct LinalgOpInstancePromotionOptions { llvm::SmallSet operandsNumbersToCopyIn; /// True if the full view should be used for the promoted buffer. DenseMap useFullTileBuffers; + /// True if the original subview size should be used. This means the full tile + /// buffer is the same size as the partial view. + bool useOriginalSubviewSize; /// Callback functions for allocation and deallocation of promoted buffers, as /// well as to copy the data into and out of these buffers. @@ -170,6 +173,7 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( options.useFullTileBuffers.value_or(llvm::SmallBitVector()); vUseFullTileBuffers.resize(linalgOp->getNumOperands(), options.useFullTileBuffersDefault); + useOriginalSubviewSize = options.useOriginalSubviewSize; for (OpOperand &opOperand : linalgOp->getOpOperands()) { int64_t operandNumber = opOperand.getOperandNumber(); @@ -237,7 +241,8 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions( // by a partial `copy` op. FailureOr mlir::linalg::promoteSubviewAsNewBuffer( OpBuilder &b, Location loc, memref::SubViewOp subView, - const AllocBufferCallbackFn &allocationFn, DataLayout &layout) { + bool useOriginalSubviewSize, const AllocBufferCallbackFn &allocationFn, + DataLayout &layout) { auto viewType = subView.getType(); auto rank = viewType.getRank(); SmallVector fullSizes; @@ -254,7 +259,8 @@ FailureOr mlir::linalg::promoteSubviewAsNewBuffer( // to look for the bound. LLVM_DEBUG(llvm::dbgs() << "Extract tightest: " << rangeValue.size << "\n"); Value size; - if (auto attr = llvm::dyn_cast_if_present(rangeValue.size)) { + if (llvm::isa_and_present(rangeValue.size) || + useOriginalSubviewSize) { size = getValueOrCreateConstantIndexOp(b, loc, rangeValue.size); } else { FailureOr upperBound = @@ -295,7 +301,8 @@ promoteSubViews(ImplicitLocOpBuilder &b, memref::SubViewOp subView = cast(v.second.getDefiningOp()); auto promotionInfo = promoteSubviewAsNewBuffer( - b, b.getLoc(), subView, options.allocationFn, layout); + b, b.getLoc(), subView, options.useOriginalSubviewSize, + options.allocationFn, layout); if (failed(promotionInfo)) return failure(); promotionInfoMap[v.first] = *promotionInfo; diff --git a/mlir/test/Dialect/Linalg/promotion_options.mlir b/mlir/test/Dialect/Linalg/promotion_options.mlir index caa72ba24316f..dbc073c2665f9 100644 --- a/mlir/test/Dialect/Linalg/promotion_options.mlir +++ b/mlir/test/Dialect/Linalg/promotion_options.mlir @@ -42,3 +42,59 @@ module attributes {transform.with_named_sequence} { transform.yield } } + +// ----- + +func.func @matmul_f32(%A: memref<512x256xf32>, %B: memref<256x512xf32>, %C: memref<256x256xf32>, %s0: index, %s1: index, %s2: index) { + %c0 = arith.constant 0 : index + %c256 = arith.constant 256 : index + %c512 = arith.constant 512 : index + scf.for %arg4 = %c0 to %c512 step %s0 { + scf.for %arg5 = %c0 to %c512 step %s1 { + scf.for %arg6 = %c0 to %c256 step %s2 { + %i0 = affine.min affine_map<(d0)[s0] -> (-d0 + 512, s0)>(%arg4)[%s0] + %i1 = affine.min affine_map<(d0)[s0] -> (-d0 + 512, s0)>(%arg5)[%s1] + %i2 = affine.min affine_map<(d0)[s0] -> (-d0 + 256, s0)>(%arg6)[%s2] + %0 = memref.subview %A[%arg4, %arg6][%i0, %i2][1, 1] : memref<512x256xf32> to memref> + %1 = memref.subview %B[%arg6, %arg5][%i2, %i1][1, 1] : memref<256x512xf32> to memref> + %2 = memref.subview %C[%arg4, %arg5][%i0, %i1][1, 1] : memref<256x256xf32> to memref> + linalg.matmul + ins(%0, %1: memref>, + memref>) + outs(%2: memref>) + } + } + } + return +} + +// CHECK-LABEL: func.func @matmul_f32( +// CHECK-SAME: %[[ARG0:.*]]: memref<512x256xf32> +// CHECK-SAME: %[[ARG1:.*]]: memref<256x512xf32> +// CHECK-SAME: %[[ARG2:.*]]: memref<256x256xf32> +// CHECK-SAME: %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index +// CHECK: %[[C4:.*]] = arith.constant 4 : index + +// CHECK: %[[i0:.*]] = affine.min +// CHECK: %[[i1:.*]] = affine.min +// CHECK: %[[i2:.*]] = affine.min + +// CHECK: %[[VAL_13:.*]] = arith.muli %[[i0]], %[[i2]] : index +// CHECK: %[[VAL_14:.*]] = arith.muli %[[VAL_13]], %[[C4]] : index +// CHECK: %[[VAL_15:.*]] = memref.alloc(%[[VAL_14]]) : memref + +// CHECK: %[[VAL_18:.*]] = arith.muli %[[i2]], %[[i1]] : index +// CHECK: %[[VAL_19:.*]] = arith.muli %[[VAL_18]], %[[C4]] : index +// CHECK: %[[VAL_20:.*]] = memref.alloc(%[[VAL_19]]) : memref + +// CHECK: %[[VAL_23:.*]] = arith.muli %[[i0]], %[[i1]] : index +// CHECK: %[[VAL_24:.*]] = arith.muli %[[VAL_23]], %[[C4]] : index +// CHECK: %[[VAL_25:.*]] = memref.alloc(%[[VAL_24]]) : memref + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.matmul"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.structured.promote %0 { use_original_subview_size } : (!transform.any_op) -> !transform.any_op + transform.yield + } +}