diff --git a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h index 33a43ce2ee7bb..3ca1bdd0fbf76 100644 --- a/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h +++ b/mlir/include/mlir/Dialect/SCF/Transforms/TileUsingInterface.h @@ -124,6 +124,17 @@ struct SCFTilingOptions { mappingVector = llvm::to_vector(mapping); return *this; } + + /// Gives hints for whether the tile sizes divide the iteration space evenly. + /// For static sizes, this is trivially verifiable (and the helpers here take + /// advantage of that), however for dynamic sizes we are always forced to be + /// pessimistic. This allows external analysis to check for divisibility and + /// pass on the info to tiling. + SmallVector divisibilityHint = {}; + SCFTilingOptions &setDivisibilityHint(ArrayRef hint) { + divisibilityHint.assign(hint.begin(), hint.end()); + return *this; + } }; /// Transformation information returned after tiling. diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 7edf19689d2e1..16edb1d8c6fce 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -248,7 +248,8 @@ static std::tuple, SmallVector> getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, ArrayRef iterationDomain, ArrayRef tileSizes, - ArrayRef numThreads) { + ArrayRef numThreads, + ArrayRef divisibilityHint) { SmallVector offsets, sizes; int materializedLoopNum = 0; @@ -260,8 +261,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, offsetExpr = d0 + d1 * s0; residualTileSizeExpr = s1 - (d0 + d1 * s0); - for (auto [nt, tileSize, loopRange] : - llvm::zip_equal(numThreads, tileSizes, iterationDomain)) { + for (auto [nt, tileSize, loopRange, divHint] : llvm::zip_equal( + numThreads, tileSizes, iterationDomain, divisibilityHint)) { // 
Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. @@ -280,7 +281,7 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, {loopRange.offset, nt, tileSize, loopRange.size}); OpFoldResult size = tileSize; - if (!isConstantIntValue(residualTileSize, 0)) { + if (!isConstantIntValue(residualTileSize, 0) && !divHint) { OpFoldResult sizeMinusOffsetPerThread = affine::makeComposedFoldedAffineApply(rewriter, loc, s0 - d0, {offset, loopRange.size}); @@ -299,7 +300,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, // `nonNegativeTileSize = affine.max(0, tileSize)`. // This `max` can be avoided if // `offset + tileSize * (numThreads - 1) < (ub - lb)` - if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size)) { + if (!canOmitTileOffsetInBoundsCheck(tileSize, nt, loopRange.size) && + !divHint) { AffineMap maxMap = AffineMap::getMultiDimIdentityMap(2, rewriter.getContext()); size = affine::makeComposedFoldedAffineMax( @@ -311,8 +313,8 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, } return {offsets, sizes}; } else { - for (auto [tileSize, loopRange] : - llvm::zip_equal(tileSizes, iterationDomain)) { + for (auto [tileSize, loopRange, divHint] : + llvm::zip_equal(tileSizes, iterationDomain, divisibilityHint)) { // Non-tiled cases, set the offset and size to the // `loopRange.offset/size`. @@ -325,8 +327,9 @@ getTileOffsetAndSizes(RewriterBase &rewriter, Location loc, ValueRange ivs, Value iv = ivs[materializedLoopNum++]; OpFoldResult offset = getAsOpFoldResult(iv); offsets.push_back(offset); - OpFoldResult size = - getBoundedTileSize(rewriter, loc, loopRange, offset, tileSize); + OpFoldResult size = divHint ? 
tileSize + : getBoundedTileSize(rewriter, loc, loopRange, + offset, tileSize); sizes.push_back(size); } return {offsets, sizes}; @@ -950,6 +953,11 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, std::tie(tileSizes, numThreads) = getUserTileSizesAndNumThreads(rewriter, op, iterationDomain, options); + // 2a. Fit the divisibility hints to the domain rank: pad missing entries + // with `false`; drop extras (`append(rank - size)` would underflow). + SmallVector divisibilityHint = options.divisibilityHint; + divisibilityHint.resize(iterationDomain.size(), false); + // Check if it is safe to tile. This is hold over from previous iterations // of tile to for-all. Consider dropping it. if (options.loopType == scf::SCFTilingOptions::LoopType::ForallOp) { @@ -982,8 +990,9 @@ mlir::scf::tileUsingSCF(RewriterBase &rewriter, TilingInterface op, -> LogicalResult { // 4a. Compute the `offsets` and `sizes` to use for tiling. SmallVector offsets, sizes; - std::tie(offsets, sizes) = getTileOffsetAndSizes( - rewriter, loc, ivs, iterationDomain, tileSizes, numThreads); + std::tie(offsets, sizes) = + getTileOffsetAndSizes(rewriter, loc, ivs, iterationDomain, tileSizes, + numThreads, divisibilityHint); // 4b. If interchange was provided, apply inverse of the interchange // to get back the offsets/sizes in the order to be specified. 
diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir index 745a82fc0da75..558ca798fffc2 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir @@ -349,3 +349,99 @@ module attributes {transform.with_named_sequence} { // CHECK-LABEL: func @check_scalar_memref_operation // CHECK-NOT: scf.for // CHECK: linalg.generic + +// ----- + +func.func @simple_matmul_assume_divisible_n(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.matmul + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.tile_using_forall %matmul [10, 20] + divisibility_hint = [false, true] mapping = [#gpu.block, #gpu.block] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 10)> +// CHECK: func.func @simple_matmul_assume_divisible_n( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = +// CHECK-SAME: (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]]) +// CHECK: %[[TS_Y:.+]] = affine.min 
#[[MAP0]](%[[IV0]])[%[[M]]] +// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0] [%[[TS_Y]], %[[K]]] [1, 1] +// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [0, %[[IV1]]] [%[[K]], 20] [1, 1] +// CHECK: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], 20] [1, 1] +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : tensor, tensor +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [%[[TS_Y]], 20] [1, 1] +// CHECK: mapping = [#gpu.block, #gpu.block] +// CHECK: return %[[RESULT]] + +// ----- + +func.func @simple_matmul_extend_divisibility(%arg0 : tensor, %arg1 : tensor, + %arg2 : tensor) -> tensor { + %0 = linalg.matmul + ins(%arg0, %arg1 : tensor, tensor) + outs(%arg2 : tensor) -> tensor + return %0 : tensor +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.tile_using_forall %matmul [10, 20] + divisibility_hint = [true] mapping = [#gpu.block, #gpu.block] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 20)> +// CHECK: func.func @simple_matmul_extend_divisibility( +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor +// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] +// CHECK-DAG: 
%[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] +// CHECK: %[[RESULT:.+]] = scf.forall (%[[IV0:[a-zA-Z0-9]+]], %[[IV1:[a-zA-Z0-9]+]]) = +// CHECK-SAME: (0, 0) to (%[[M]], %[[N]]) step (10, 20) shared_outs(%[[INIT:.+]] = %[[ARG2]]) +// CHECK: %[[TS_X:.+]] = affine.min #[[MAP0]](%[[IV1]])[%[[N]]] +// CHECK: %[[LHS_TILE:.+]] = tensor.extract_slice %[[ARG0]] +// CHECK-SAME: [%[[IV0]], 0] [10, %[[K]]] [1, 1] +// CHECK: %[[RHS_TILE:.+]] = tensor.extract_slice %[[ARG1]] +// CHECK-SAME: [0, %[[IV1]]] [%[[K]], %[[TS_X]]] [1, 1] +// CHECK: %[[INIT_TILE:.+]] = tensor.extract_slice %[[INIT]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [10, %[[TS_X]]] [1, 1] +// CHECK: %[[GEMM_TILE:.+]] = linalg.matmul +// CHECK-SAME: ins(%[[LHS_TILE]], %[[RHS_TILE]] : tensor<10x?xf32>, tensor +// CHECK-SAME: outs(%[[INIT_TILE]] : +// CHECK: scf.forall.in_parallel { +// CHECK: tensor.parallel_insert_slice %[[GEMM_TILE]] into %[[INIT]] +// CHECK-SAME: [%[[IV0]], %[[IV1]]] [10, %[[TS_X]]] [1, 1] +// CHECK: mapping = [#gpu.block, #gpu.block] +// CHECK: return %[[RESULT]] diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp index 45d6ae3820159..41446ba176a3b 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp @@ -21,6 +21,7 @@ #include "mlir/IR/Dominance.h" #include "mlir/IR/OpImplementation.h" #include "mlir/Interfaces/TilingInterface.h" +#include "llvm/ADT/SmallVectorExtras.h" #define GET_OP_CLASSES #include "TestTilingInterfaceTransformOps.h.inc" @@ -54,12 +55,11 @@ static llvm::SmallDenseSet collectTiledAndFusedOps(Operation *op) { /// Apply a tile and fuse transformation to all payload ops and store both the /// tiled operation as well as the created tile loops. 
template -static LogicalResult -applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, - Range &&payloadOps, unsigned numLoops, - ArrayRef tileSizes, - ArrayRef interchange, bool useForall, - TransformResults &transformResults) { +static LogicalResult applyTileAndFuseToAll( + RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, + unsigned numLoops, ArrayRef tileSizes, + ArrayRef interchange, ArrayRef divisibilityHint, + bool useForall, TransformResults &transformResults) { SmallVector tiledOps; SmallVector> loopOps(numLoops); @@ -85,6 +85,7 @@ applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp, if (useForall) { tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); } + tilingOptions.setDivisibilityHint(divisibilityHint); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.setTilingOptions(tilingOptions); @@ -151,13 +152,16 @@ transform::TestFuseAndYieldOp::apply(TransformRewriter &rewriter, SmallVector tileInterchange = extractFromIntegerArrayAttr(getTileInterchange()); + SmallVector divisibilityHint( + getDivisibilityHint().getAsValueRange()); + SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); LogicalResult result = applyTileAndFuseToAll( rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr, - tileInterchange, getUseForall(), transformResults); + tileInterchange, divisibilityHint, getUseForall(), transformResults); return failed(result) ? 
DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } @@ -237,7 +241,8 @@ template static LogicalResult applyTileToAll(RewriterBase &rewriter, Operation *transformOp, Range &&payloadOps, ArrayRef tileSizes, - ArrayRef interchange, std::optional mapping, + ArrayRef interchange, ArrayRef divisibilityHint, + std::optional mapping, TransformResults &transformResults) { SmallVector tiledOps; SmallVector loopOps; @@ -251,6 +256,7 @@ applyTileToAll(RewriterBase &rewriter, Operation *transformOp, if (mapping) { tilingOptions.setMapping(mapping.value().getValue()); } + tilingOptions.setDivisibilityHint(divisibilityHint); tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); rewriter.setInsertionPoint(target); @@ -287,9 +293,12 @@ transform::TestTileUsingForallOp::apply(TransformRewriter &rewriter, SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); - LogicalResult result = - applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()), - tileSizesOfr, interchange, getMapping(), transformResults); + SmallVector divisibilityHint( + getDivisibilityHint().getAsValueRange()); + + LogicalResult result = applyTileToAll( + rewriter, getOperation(), state.getPayloadOps(getTarget()), tileSizesOfr, + interchange, divisibilityHint, getMapping(), transformResults); return failed(result) ? 
DiagnosedSilenceableFailure::definiteFailure() : DiagnosedSilenceableFailure::success(); } @@ -363,11 +372,15 @@ transform::TestFuseUsingForallOp::apply(TransformRewriter &rewriter, SmallVector tileInterchange = extractFromIntegerArrayAttr(getInterchange()); + SmallVector divisibilityHint( + getDivisibilityHint().getAsValueRange()); + scf::SCFTilingOptions tilingOptions; tilingOptions.interchangeVector = tileInterchange; SmallVector tileSizesOfr = getAsIndexOpFoldResult(rewriter.getContext(), tileSizes); tilingOptions = tilingOptions.setTileSizes(tileSizesOfr); + tilingOptions = tilingOptions.setDivisibilityHint(divisibilityHint); tilingOptions.setLoopType(scf::SCFTilingOptions::LoopType::ForallOp); scf::SCFTileAndFuseOptions tileAndFuseOptions; tileAndFuseOptions.tilingOptions = tilingOptions; diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td index 98f7145c99cb1..e7d8732808d78 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td +++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td @@ -38,12 +38,14 @@ def TestFuseAndYieldOp : Op:$tile_sizes, DefaultValuedAttr:$tile_interchange, + DefaultValuedOptionalAttr:$divisibility_hint, DefaultValuedAttr:$use_forall); let results = (outs TransformHandleTypeInterface:$transfomed, Variadic:$loops); let assemblyFormat = [{ $target ($tile_sizes^)? (`interchange` $tile_interchange^)? + (`divisibility_hint` `=` $divisibility_hint^)? (`use_forall` $use_forall^)? attr-dict `:` functional-type(operands, results) }]; @@ -91,12 +93,14 @@ def TestTileUsingForallOp : Op:$tile_sizes, DefaultValuedOptionalAttr:$interchange, + DefaultValuedOptionalAttr:$divisibility_hint, OptionalAttr:$mapping); let results = (outs TransformHandleTypeInterface:$tiled_op, Variadic:$loops); let assemblyFormat = [{ $target ($tile_sizes^)? (`interchange` `=` $interchange^)? 
+ (`divisibility_hint` `=` $divisibility_hint^)? (`mapping` `=` $mapping^)? attr-dict `:` functional-type(operands, results) }]; @@ -114,12 +118,14 @@ def TestFuseUsingForallOp : Op:$tile_sizes, DefaultValuedOptionalAttr:$interchange, + DefaultValuedOptionalAttr:$divisibility_hint, OptionalAttr:$mapping); let results = (outs TransformHandleTypeInterface:$tiled_ops, Variadic:$loops); let assemblyFormat = [{ $root_op ($tile_sizes^)? (`interchange` $interchange^)? + (`divisibility_hint` `=` $divisibility_hint^)? (`mapping` `=` $mapping^)? attr-dict `:` functional-type(operands, results) }];