diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
index a8d135caa74f0..f3e40aaa29075 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Passes.td
@@ -182,6 +182,68 @@ def ResolveShapedTypeResultDimsPass : Pass<"resolve-shaped-type-result-dims"> {
   ];
 }
 
+def ReifyResultShapesPass : Pass<"reify-result-shapes"> {
+  let summary = "Reifies the result shapes of `tensor::PadOp` and `tensor::ConcatOp`.";
+  let description = [{
+    This pass reifies the shapes of a subset of
+    `ReifyRankedShapedTypeOpInterface` ops with `tensor` results.
+
+    The pass currently only supports result shape reification for:
+      - tensor::PadOp
+      - tensor::ConcatOp
+
+    It addresses a representation gap where implicit op semantics are needed
+    to infer static result types from dynamic operands, but it does so by
+    using `ReifyRankedShapedTypeOpInterface` as the source of truth rather
+    than the op itself. As a consequence, this cannot generalize today.
+
+    TODO: in the future, we should consider coupling this information with op
+    "transfer functions" (e.g. `IndexingMapOpInterface`) to provide a source
+    of truth that can work across result shape inference, canonicalization,
+    and op verifiers.
+
+    The pass replaces an operation with its reified version when more static
+    information can be derived, and inserts casts back to the original result
+    types when result shapes are updated.
+
+    Example:
+    ```mlir
+    #map = affine_map<(d0) -> (-d0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
+        -> tensor<1x?x64xf32>
+    {
+      %0 = affine.apply #map(%arg1)
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
+        : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+      return %padded : tensor<1x?x64xf32>
+    }
+
+    // mlir-opt --reify-result-shapes
+    #map = affine_map<()[s0] -> (-s0 + 256)>
+    func.func @func(%arg0: f32, %arg1: index, %arg2: tensor<64x?x64xf32>)
+        -> tensor<1x?x64xf32>
+    {
+      %0 = affine.apply #map()[%arg1]
+      %extracted_slice = tensor.extract_slice %arg2[0, 0, 0] [1, %arg1, 64] [1, 1, 1]
+        : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+      %padded = tensor.pad %extracted_slice low[0, 0, 0] high[0, %0, 0] {
+      ^bb0(%arg3: index, %arg4: index, %arg5: index):
+        tensor.yield %arg0 : f32
+      } : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+      %cast = tensor.cast %padded : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+      return %cast : tensor<1x?x64xf32>
+    }
+    ```
+  }];
+  let dependentDialects = [
+    "affine::AffineDialect", "memref::MemRefDialect", "tensor::TensorDialect"
+  ];
+}
+
 def ExpandStridedMetadataPass : Pass<"expand-strided-metadata"> {
   let summary = "Expand memref operations into easier to analyze constructs";
   let description = [{
diff --git a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
index c2b8cb05be922..33e3d94f02b1c 100644
--- a/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/MemRef/Transforms/Transforms.h
@@ -23,6 +23,7 @@ class RewritePatternSet;
 class RewriterBase;
 class Value;
 class ValueRange;
+class ReifyRankedShapedTypeOpInterface;
 
 namespace arith {
 class WideIntEmulationConverter;
@@ -208,7 +209,6 @@ FailureOr<Value> replaceWithIndependentOp(RewriterBase &rewriter,
 memref::AllocaOp allocToAlloca(
     RewriterBase &rewriter, memref::AllocOp alloc,
     function_ref<bool(memref::AllocOp, memref::DeallocOp)> filter = nullptr);
-
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
index 637f5ec1c9f9b..9049faccadef3 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/MemRef/Transforms/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_dialect_library(MLIRMemRefTransforms
   IndependenceTransforms.cpp
   MultiBuffer.cpp
   NormalizeMemRefs.cpp
+  ReifyResultShapes.cpp
   ResolveShapedTypeResultDims.cpp
   RuntimeOpVerification.cpp
diff --git a/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
new file mode 100644
index 0000000000000..e6b9e2f7e8213
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/Transforms/ReifyResultShapes.cpp
@@ -0,0 +1,159 @@
+//===- ReifyResultShapes.cpp - Reify result shapes ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This transform reifies result shapes of `ReifyRankedShapedTypeOpInterface`
+// operations with ranked `memref` and `tensor` results.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/Transforms/Passes.h"
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Interfaces/DestinationStyleOpInterface.h"
+#include "mlir/Interfaces/InferTypeOpInterface.h"
+#include "llvm/Support/InterleavedRange.h"
+
+#define DEBUG_TYPE "reify-result-shapes"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE << "]: ")
+
+namespace mlir {
+namespace memref {
+#define GEN_PASS_DEF_REIFYRESULTSHAPESPASS
+#include "mlir/Dialect/MemRef/Transforms/Passes.h.inc"
+} // namespace memref
+} // namespace mlir
+
+using namespace mlir;
+
+/// Reifies the results of `op`, potentially replacing `op` with a reified
+/// version. Returns `failure` if `mlir::reifyResultShapes` fails; otherwise
+/// it always succeeds. Users of this transform should expect it to modify
+/// the IR even when it fails. If any of the result types changes, the
+/// transform inserts cast operations back to the old types to keep the IR
+/// consistent.
+static LogicalResult reifyOpResultShapes(RewriterBase &rewriter,
+                                         ReifyRankedShapedTypeOpInterface op) {
+  LLVM_DEBUG({ DBGS() << " reifying op: " << op << "\n"; });
+  // Get the reified result shapes.
+  ReifiedRankedShapedTypeDims reifiedResultShapes;
+  if (failed(mlir::reifyResultShapes(rewriter, op, reifiedResultShapes)) ||
+      reifiedResultShapes.empty()) {
+    return op->emitWarning() << "failed to get the reified shapes";
+  }
+
+  bool modified = false;
+  // Compute the new output types.
+  SmallVector<Type> outTypes;
+  for (const auto &[oldTy, reifiedShape] :
+       llvm::zip(op->getResultTypes(), reifiedResultShapes)) {
+    // Skip if it's not a memref or tensor type.
+    if (!isa<RankedTensorType, MemRefType>(oldTy)) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+
+    ShapedType shapedTy = dyn_cast<ShapedType>(oldTy);
+
+    SmallVector<int64_t> shape = llvm::to_vector(shapedTy.getShape());
+    for (auto &&[dim, ofr] : llvm::zip_equal(shape, reifiedShape)) {
+      std::optional<int64_t> maybeCst = getConstantIntValue(ofr);
+      // If the reified dim is dynamic, set it appropriately.
+      if (!maybeCst.has_value()) {
+        dim = ShapedType::kDynamic;
+        continue;
+      }
+      // Set the static dim.
+      dim = *maybeCst;
+    }
+
+    // If the shape didn't change, continue.
+    if (shape == shapedTy.getShape()) {
+      outTypes.push_back(oldTy);
+      continue;
+    }
+    modified = true;
+    outTypes.push_back(shapedTy.cloneWith(shape, shapedTy.getElementType()));
+  }
+
+  // Return if we don't need to update.
+  if (!modified) {
+    LLVM_DEBUG({ DBGS() << "- op doesn't require update\n"; });
+    return success();
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- oldTypes: " << llvm::interleaved_array(op->getResultTypes())
+           << "\n";
+    DBGS() << "- outTypes: " << llvm::interleaved_array(outTypes) << "\n";
+  });
+
+  // We now have outTypes that need to be turned into cast ops.
+  Location loc = op->getLoc();
+  SmallVector<Value> newResults;
+  // TODO: `mlir::reifyResultShapes` and op verifiers may not agree at the
+  // moment. This is a confluence problem that will need to be addressed.
+  // For now, we know PadOp and ConcatOp are fine.
+  assert((isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation())) &&
+         "incorrect op");
+  Operation *newOp = rewriter.clone(*op);
+  for (auto [reifiedTy, oldRes] : llvm::zip(outTypes, op->getResults())) {
+    OpResult newRes = newOp->getResult(oldRes.getResultNumber());
+    Type oldTy = oldRes.getType();
+    // Continue if the type remained invariant or is not shaped.
+    if (oldTy == reifiedTy || !isa<ShapedType>(oldTy)) {
+      newResults.push_back(newRes);
+      continue;
+    }
+
+    // Update the type.
+    newRes.setType(reifiedTy);
+    if (isa<RankedTensorType>(reifiedTy)) {
+      newResults.push_back(rewriter.create<tensor::CastOp>(loc, oldTy, newRes));
+    } else {
+      assert(isa<MemRefType>(reifiedTy) && "expected a memref type");
+      newResults.push_back(rewriter.create<memref::CastOp>(loc, oldTy, newRes));
+    }
+  }
+
+  LLVM_DEBUG({
+    DBGS() << "- reified results " << llvm::interleaved_array(newResults)
+           << "\n";
+  });
+  rewriter.replaceOp(op, newResults);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Pass registration
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct ReifyResultShapesPass final
+    : public memref::impl::ReifyResultShapesPassBase<ReifyResultShapesPass> {
+  void runOnOperation() override;
+};
+} // namespace
+
+void ReifyResultShapesPass::runOnOperation() {
+  SmallVector<ReifyRankedShapedTypeOpInterface> ops;
+  getOperation()->walk([&](ReifyRankedShapedTypeOpInterface op) {
+    // Handle ops that are not DPS and that do not carry tied operand shapes.
+    // For now, limit to tensor::PadOp and tensor::ConcatOp.
+    if (!isa<tensor::PadOp, tensor::ConcatOp>(op.getOperation()))
+      return;
+    ops.push_back(op);
+  });
+  IRRewriter rewriter(&getContext());
+  for (ReifyRankedShapedTypeOpInterface op : ops) {
+    rewriter.setInsertionPoint(op);
+    (void)reifyOpResultShapes(rewriter, op);
+  }
+}
diff --git a/mlir/test/Dialect/Tensor/reify-shapes.mlir b/mlir/test/Dialect/Tensor/reify-shapes.mlir
new file mode 100644
index 0000000000000..5569d90f8b731
--- /dev/null
+++ b/mlir/test/Dialect/Tensor/reify-shapes.mlir
@@ -0,0 +1,31 @@
+// RUN: mlir-opt -reify-result-shapes %s | FileCheck %s
+
+// The test below checks concat op reification. In the first case, no cast is inserted, while in the second a cast gets inserted.
+// CHECK-LABEL: func.func @concat_reification
+func.func @concat_reification(%arg0: tensor<4x7x3xf32>, %arg1 : tensor<4x4x3xf32>, %arg2: tensor<?x?x?xf32>)
+  -> (tensor<4x11x3xf32>, tensor<?x?x?xf32>) {
+  // CHECK: %[[RES0:.*]] = tensor.concat dim(1) %{{.*}} : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  %1 = tensor.concat dim(1) %arg0, %arg1 : (tensor<4x7x3xf32>, tensor<4x4x3xf32>) -> tensor<4x11x3xf32>
+  // CHECK: %[[V0:.*]] = tensor.concat dim(2) %{{.*}} : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<4x7x?xf32>
+  // CHECK: %[[RES1:.*]] = tensor.cast %[[V0]] : tensor<4x7x?xf32> to tensor<?x?x?xf32>
+  %2 = tensor.concat dim(2) %arg0, %arg2 : (tensor<4x7x3xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
+  // CHECK: return %[[RES0]], %[[RES1]] : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+  return %1, %2 : tensor<4x11x3xf32>, tensor<?x?x?xf32>
+}
+
+// CHECK-LABEL: func.func @pad_reification
+func.func @pad_reification(%cst : f32, %idx : index, %t: tensor<64x?x64xf32>) -> tensor<1x?x64xf32> {
+  %pad_amt = affine.apply affine_map<(d0) -> (-d0 + 256)>(%idx)
+  %es = tensor.extract_slice %t[0, 0, 0] [1, %idx, 64] [1, 1, 1]
+    : tensor<64x?x64xf32> to tensor<1x?x64xf32>
+
+  // CHECK: tensor.pad
+  // CHECK: : tensor<1x?x64xf32> to tensor<1x256x64xf32>
+  // CHECK: tensor.cast %{{.*}} : tensor<1x256x64xf32> to tensor<1x?x64xf32>
+  %padded = tensor.pad %es low[0, 0, 0] high[0, %pad_amt, 0] {
+    ^bb0(%a: index, %b: index, %c: index):
+      tensor.yield %cst : f32
+  } : tensor<1x?x64xf32> to tensor<1x?x64xf32>
+
+  return %padded : tensor<1x?x64xf32>
+}
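
Note for anyone trying the pass from C++ rather than through the `mlir-opt` RUN line above: a minimal sketch of driving it from a pass pipeline follows. The factory name `memref::createReifyResultShapesPass()` is the default that TableGen derives from the pass definition in Passes.td, and `input.mlir` is a hypothetical file name; both are assumptions for illustration, not part of this patch.

```cpp
// Minimal sketch (not part of the patch): parse a module and run the new
// pass over it. Assumes the TableGen-generated factory
// memref::createReifyResultShapesPass() and a hypothetical "input.mlir".
// The pass's dependent dialects are loaded by the pass manager, but the
// context must know every dialect used by the parsed input.
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Pass/PassManager.h"

using namespace mlir;

static LogicalResult runReifyResultShapes() {
  // Register the dialects appearing in the test IR above.
  DialectRegistry registry;
  registry.insert<affine::AffineDialect, func::FuncDialect,
                  memref::MemRefDialect, tensor::TensorDialect>();
  MLIRContext ctx(registry);

  // Parse the module and run the pass, mirroring
  // `mlir-opt -reify-result-shapes input.mlir`.
  OwningOpRef<ModuleOp> module = parseSourceFile<ModuleOp>("input.mlir", &ctx);
  if (!module)
    return failure();

  PassManager pm(&ctx);
  pm.addPass(memref::createReifyResultShapesPass());
  return pm.run(*module);
}
```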