diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td index 1e48a5e3a20ee..ae2ae85e8479b 100644 --- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td +++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td @@ -34,7 +34,7 @@ class Linalg_RelayoutOp traits = []> : Op, DestinationStyleOpInterface, LinalgRelayoutOpInterface, - ConditionallySpeculatable, NoMemoryEffect, + ConditionallySpeculatable, DeclareOpInterfaceMethods, DeclareOpInterfaceMethods, TypesMatchWith<"result type matches type of dest", "dest", "result", @@ -43,10 +43,10 @@ class Linalg_RelayoutOp traits = []> : code commonExtraClassDeclaration = [{ size_t getSourceRank() { return getSourceType().getRank(); }; size_t getDestRank() { return getDestType().getRank(); }; - RankedTensorType getSourceType() { - return ::llvm::cast(getSource().getType()); }; - RankedTensorType getDestType() { - return ::llvm::cast(getDest().getType()); }; + ShapedType getSourceType() { + return ::llvm::cast(getSource().getType()); }; + ShapedType getDestType() { + return ::llvm::cast(getDest().getType()); }; MutableOperandRange getDpsInitsMutable() { return getDestMutable(); } @@ -152,14 +152,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ // Note: Only tiled dimensions can be padded. ``` }]; - let arguments = (ins AnyRankedTensor:$source, - AnyRankedTensor:$dest, + let arguments = (ins AnyShaped:$source, + AnyShaped:$dest, Optional:$padding_value, DefaultValuedOptionalAttr:$outer_dims_perm, DenseI64ArrayAttr:$inner_dims_pos, Variadic:$inner_tiles, DenseI64ArrayAttr:$static_inner_tiles); - let results = (outs AnyRankedTensor:$result); + let results = (outs AnyShaped:$result); let assemblyFormat = [{ $source (`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)? 
@@ -190,7 +190,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [ // Method to get the `RankedTensorType` of the result based on the inner // tiles, position of the inner tiles (innerDimsPos) and interchange vector // of outer loops (outerDimsPerm). - static RankedTensorType inferPackedType(RankedTensorType sourceType, + static RankedTensorType inferPackedTensorType(RankedTensorType sourceType, + ArrayRef innerTileSizes, ArrayRef innerDimsPos, + ArrayRef outerDimsPerm = {}); + + // Method to get the `MemRefType` of the result based on the inner + // tiles, position of the inner tiles (innerDimsPos) and interchange vector + // of outer loops (outerDimsPerm). + static MemRefType inferPackedMemRefType(MemRefType sourceType, + ArrayRef innerTileSizes, ArrayRef innerDimsPos, + ArrayRef outerDimsPerm = {}); + + // Returns the shape of the packed type. It is a shared helper that helps implement the type inference methods in a way that ensures that they agree on which dimensions are dynamic. + static SmallVector inferPackedShape(ArrayRef inputShape, ArrayRef innerTileSizes, ArrayRef innerDimsPos, ArrayRef outerDimsPerm = {}); @@ -279,13 +291,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> { : tensor<8x16x8x32xf32> -> tensor<128x256xf32> ``` }]; - let arguments = (ins AnyRankedTensor:$source, - AnyRankedTensor:$dest, + let arguments = (ins AnyShaped:$source, + AnyShaped:$dest, DefaultValuedOptionalAttr:$outer_dims_perm, DenseI64ArrayAttr:$inner_dims_pos, Variadic:$inner_tiles, DenseI64ArrayAttr:$static_inner_tiles); - let results = (outs AnyRankedTensor:$result); + let results = (outs AnyShaped:$result); let assemblyFormat = [{ $source (`outer_dims_perm` `=` $outer_dims_perm^)? 
diff --git a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp index 6119097456d1f..ad9621257f5df 100644 --- a/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp +++ b/mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp @@ -77,10 +77,8 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern { if constexpr (llvm::is_one_of::value) { - if (isa(originalOperand.getType())) { - // TODO: Support memref type in variable operands + if (isa(originalOperand.getType())) return rewriter.notifyMatchFailure(op, "memref is not supported yet"); - } } convertedOperands.push_back(convertedOperand); } diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp index d6b093c5fb86b..da8b13355a17e 100644 --- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp +++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp @@ -46,6 +46,7 @@ #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringSet.h" #include "llvm/ADT/TypeSwitch.h" +#include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/LogicalResult.h" #include "llvm/Support/MathExtras.h" @@ -804,7 +805,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern { rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)})); } - RankedTensorType srcPadType = srcPadOp.getSourceType(); + ShapedType srcPadType = srcPadOp.getSourceType(); SmallVector newSizes; for (int i = 0, e = srcPadType.getRank(); i < e; ++i) { if (srcPadType.isDynamicDim(i)) { @@ -4427,15 +4428,28 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); }); }; + // Verify that the source and destination are ranked types. + if (!packOrUnPack.getSourceType().hasRank() || + !packOrUnPack.getDestType().hasRank()) { + return op->emitError("expected both source and destination to have rank"); + } + // Verify tiles. Do not allow zero tiles. 
SmallVector mixedTiles = packOrUnPack.getMixedTiles(); if (hasZeros(mixedTiles)) return op->emitError("invalid zero tile factor"); + // Verify that the Operation does not have mixed tensor/buffer semantics. + if (!packOrUnPack.hasPureBufferSemantics() && + !packOrUnPack.hasPureTensorSemantics()) { + return op->emitError("mixing tensor and buffer semantics is not allowed"); + } + // Verify inner_dims_pos and outer_dims_perm. - RankedTensorType unpackedType = (std::is_same::value) - ? packOrUnPack.getSourceType() - : packOrUnPack.getDestType(); + ShapedType unpackedType = (std::is_same::value) + ? packOrUnPack.getSourceType() + : packOrUnPack.getDestType(); + size_t unpackedRank = unpackedType.getRank(); ArrayRef innerDimsPos = packOrUnPack.getInnerDimsPos(); ArrayRef outerDimPerm = packOrUnPack.getOuterDimsPerm(); @@ -4472,12 +4486,23 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) { // Verify result shape is greater than the minimum expected // by the pack operation, and that the output shape // represents full tiles. 
- RankedTensorType expectedPackedType = PackOp::inferPackedType( - unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm); - if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) { + SmallVector expectedPackedShape = PackOp::inferPackedShape( + unpackedType.getShape(), packOrUnPack.getStaticTiles(), + packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm()); + + if (!areAllInBound(expectedPackedShape, packedType.getShape())) { + auto elementType = unpackedType.getElementType(); + Type expectedType, actualType; + if (packOrUnPack.hasPureTensorSemantics()) { + expectedType = RankedTensorType::get(expectedPackedShape, elementType); + actualType = RankedTensorType::get(packedType.getShape(), elementType); + } else { + expectedType = MemRefType::get(expectedPackedShape, elementType); + actualType = MemRefType::get(packedType.getShape(), elementType); + } return op->emitError("the shape of output is not large enough to hold the " "packed data. Expected at least ") - << expectedPackedType << ", got " << packedType; + << expectedType << ", got " << actualType; } if (!llvm::all_of( llvm::zip(packedType.getShape().take_back(mixedTiles.size()), @@ -4681,13 +4706,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef ofrs) { return result; } -/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of -/// the packed type. Having a shared helper helps implement these two methods in -/// a way that ensures that they agree on which dimensions are dynamic. 
-static SmallVector getPackOpResultTypeShape( - ArrayRef sourceShape, ArrayRef innerTileSizes, - ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { - SmallVector resultShape = llvm::to_vector(sourceShape); +SmallVector PackOp::inferPackedShape(ArrayRef inputShape, + ArrayRef innerTileSizes, + ArrayRef innerDimsPos, + ArrayRef outerDimsPerm) { + SmallVector resultShape = llvm::to_vector(inputShape); for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) { if (ShapedType::isDynamic(resultShape[tiledDim.value()])) continue; @@ -4727,9 +4750,9 @@ SmallVector PackOp::getResultShape( resultDims.append(innerTileSizes.begin(), innerTileSizes.end()); SmallVector resultTypeShape = - getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims), - asShapeWithAnyValueAsDynamic(innerTileSizes), - innerDimsPos, outerDimsPerm); + inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims), + asShapeWithAnyValueAsDynamic(innerTileSizes), + innerDimsPos, outerDimsPerm); // Fix-up `resultDims` to ensure that they are Value's if and only if the // result type shape says it's a dynamic dim. This is needed as callers may @@ -4747,13 +4770,21 @@ SmallVector PackOp::getResultShape( /// Get the expected packed type based on source type, tile factors, position of /// the inner tiles and permutation of the outer tiled loop. 
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType, +RankedTensorType PackOp::inferPackedTensorType( + RankedTensorType sourceType, ArrayRef innerTileSizes, + ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { + SmallVector resultShape = inferPackedShape( + sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm); + return RankedTensorType::get(resultShape, sourceType.getElementType()); +} + +MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType, ArrayRef innerTileSizes, ArrayRef innerDimsPos, ArrayRef outerDimsPerm) { - SmallVector resultShape = getPackOpResultTypeShape( + SmallVector resultShape = inferPackedShape( sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm); - return RankedTensorType::get(resultShape, sourceType.getElementType()); + return MemRefType::get(resultShape, sourceType.getElementType()); } Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source, @@ -4802,6 +4833,45 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc, getPaddingValue(), metadata.outerDimsPerm); } +template +static void getPackUnPackEffectsImpl( + OpTy op, SmallVectorImpl> + &effects) { + // No memory effects for pure tensor semantics + if (op.hasPureTensorSemantics()) + return; + + for (OpOperand &opOperand : op.getOperation()->getOpOperands()) { + if (!llvm::isa(opOperand.get().getType())) + continue; + + if (&opOperand == &op.getSourceMutable()) { + effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } else if (&opOperand == &op.getDestMutable()) { + effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0, + /*effectOnFullRegion=*/true, + SideEffects::DefaultResource::get()); + } + } +} + +void PackOp::getEffects( + SmallVectorImpl> 
+ &effects) { + getPackUnPackEffectsImpl(*this, effects); +} + +void UnPackOp::getEffects( + SmallVectorImpl> + &effects) { + getPackUnPackEffectsImpl(*this, effects); +} + /// Returns true if the tiles and the tiled dims are constant. template bool areTilesAndTiledDimsAllConstant(OpTy op) { @@ -4821,6 +4891,9 @@ bool areTilesAndTiledDimsAllConstant(OpTy op) { } Speculation::Speculatability PackOp::getSpeculatability() { + if (!hasPureTensorSemantics()) + return Speculation::NotSpeculatable; + if (getPaddingValue()) return Speculation::Speculatable; @@ -4933,7 +5006,8 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { // Insert tensor.cast ops if static shape inference is available.. SmallVector srcShape, destShape; - if (inferStaticShape(packOp, srcShape, destShape)) { + if (inferStaticShape(packOp, srcShape, destShape) && + packOp.hasPureTensorSemantics()) { Location loc = packOp.getLoc(); Value source = packOp.getSource(); if (srcShape != packOp.getSourceType().getShape()) { @@ -4942,7 +5016,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { rewriter.create(loc, newSrcType, packOp.getSource()); } Value dest = packOp.getDest(); - RankedTensorType originalResultType = packOp.getDestType(); + ShapedType originalResultType = packOp.getDestType(); bool needUpdateDestType = (destShape != originalResultType.getShape()); if (needUpdateDestType) { auto newDestType = packOp.getDestType().clone(destShape); @@ -4961,6 +5035,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { rewriter.create(loc, originalResultType, packOp); rewriter.replaceAllUsesExcept(packOp, castOp, castOp); } + return success(); } @@ -4968,8 +5043,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { } template -static bool isLikePadUnPad(PackOrUnpackOp packOp, - RankedTensorType packedTensorType) { +static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) { 
static_assert(std::is_same::value || std::is_same::value, "Function meant for pack/unpack"); @@ -5002,17 +5076,20 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp, bool PackOp::isLikePad() { auto packedTensorType = - llvm::cast((*this)->getResultTypes().front()); + llvm::dyn_cast((*this)->getResultTypes().front()); return isLikePadUnPad(*this, packedTensorType); } OpFoldResult PackOp::fold(FoldAdaptor adaptor) { + if (!hasPureTensorSemantics()) + return {}; + std::optional paddingValue; if (auto pad = adaptor.getPaddingValue()) paddingValue = pad; if (OpFoldResult reshapedSource = reshapeConstantSource( llvm::dyn_cast_if_present(adaptor.getSource()), - getDestType(), paddingValue)) + cast(getDestType()), paddingValue)) return reshapedSource; return {}; } @@ -5039,6 +5116,10 @@ struct FoldTensorCastPackOp : public OpRewritePattern { if (!tensor::hasFoldableTensorCastOperand(op)) return failure(); + // TODO: Support Memref PackOp. Temporarily return failure. + if (!op.hasPureTensorSemantics()) + return failure(); + SmallVector newResultTypes(op->getResultTypes()); SmallVector newOperands = tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes); @@ -5119,6 +5200,9 @@ LogicalResult UnPackOp::verify() { } Speculation::Speculatability UnPackOp::getSpeculatability() { + if (!hasPureTensorSemantics()) + return Speculation::NotSpeculatable; + // See PackOp::getSpeculatability. 
if (!areTilesAndTiledDimsAllConstant(*this)) return Speculation::NotSpeculatable; @@ -5296,14 +5380,16 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp, } bool UnPackOp::isLikeUnPad() { - RankedTensorType packedTensorType = getSourceType(); + ShapedType packedTensorType = getSourceType(); return isLikePadUnPad(*this, packedTensorType); } OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) { + if (!hasPureTensorSemantics()) + return {}; if (OpFoldResult reshapedSource = reshapeConstantSource( llvm::dyn_cast_if_present(adaptor.getSource()), - getResult().getType())) + cast(getResult().getType()))) return reshapedSource; return {}; } @@ -5330,6 +5416,10 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern { if (!tensor::hasFoldableTensorCastOperand(op)) return failure(); + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!op.hasPureTensorSemantics()) + return failure(); + SmallVector newResultTypes(op->getResultTypes()); SmallVector newOperands = tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes); diff --git a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp index 81842e4bea631..cdd9d3da9bcf8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/BlockPackMatmul.cpp @@ -91,6 +91,10 @@ transposePackedMatmul(RewriterBase &rewriter, linalg::LinalgOp linalgOp, linalg::PackOp packOp, AffineMap operandMap, ArrayRef blocksStartDimPos, bool transposeOuterBlocks, bool transposeInnerBlocks) { + // TODO: Support Memref PackOp. Temporarily return failure. 
+ if (!packOp.hasPureTensorSemantics()) + return failure(); + assert(operandMap.getNumDims() >= 4 && "expected at least 4D prepacked matmul"); assert(blocksStartDimPos.size() >= 2 && diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp index 9f5000b70b6f6..5f38f5a84ac64 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp @@ -461,6 +461,9 @@ struct BubbleUpPackOpThroughGenericOpPattern LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { + if (!packOp.hasPureTensorSemantics()) + return failure(); + auto genericOp = bubbleUpPackOpThroughGenericOp(rewriter, packOp, controlFn); if (failed(genericOp)) @@ -483,6 +486,9 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern { LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { + if (!packOp.hasPureTensorSemantics()) + return failure(); + auto padOp = packOp.getSource().getDefiningOp(); if (!padOp) return failure(); @@ -651,6 +657,9 @@ static LogicalResult bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp, linalg::PackOp packOp, PatternRewriter &rewriter) { + if (!packOp.hasPureTensorSemantics()) + return failure(); + SmallVector innerTileSizes = packOp.getStaticTiles(); ArrayRef innerDimsPos = packOp.getInnerDimsPos(); ArrayRef outerDimsPerm = packOp.getOuterDimsPerm(); @@ -757,6 +766,9 @@ static LogicalResult bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, linalg::PackOp packOp, PatternRewriter &rewriter) { + if (!packOp.hasPureTensorSemantics()) + return failure(); + // Outer dimensions permutation is not supported currently. // TODO: Handle outer_dims_perm variants. 
ArrayRef outerDimsPerm = packOp.getOuterDimsPerm(); @@ -808,7 +820,7 @@ bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp, // If reassociation is not possible, then reordering cannot happen. // This can be caused by pack padding affecting previously expanded // dimensions or packing extending dimensions. - RankedTensorType newPackType = linalg::PackOp::inferPackedType( + RankedTensorType newPackType = linalg::PackOp::inferPackedTensorType( expandOp.getSrcType(), packOp.getStaticInnerTiles(), projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector{}); auto reassocExpand = @@ -840,6 +852,9 @@ class BubbleUpPackOpThroughReshapeOp final LogicalResult matchAndRewrite(linalg::PackOp packOp, PatternRewriter &rewriter) const override { + if (!packOp.hasPureTensorSemantics()) + return failure(); + Operation *srcOp = packOp.getSource().getDefiningOp(); // Currently only support when the pack op is the only user. if (!srcOp || !(srcOp->getNumResults() == 1) || @@ -893,6 +908,9 @@ class BubbleUpPackOpThroughReshapeOp final static LogicalResult pushDownUnPackOpThroughExpandShape( linalg::UnPackOp unPackOp, tensor::ExpandShapeOp expandOp, PatternRewriter &rewriter, ControlPropagationFn controlFn) { + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + // User controlled propagation function. 
if (!controlFn(&expandOp.getSrcMutable())) return failure(); @@ -943,7 +961,7 @@ static LogicalResult pushDownUnPackOpThroughExpandShape( nextPos += 1; } - RankedTensorType newExpandType = linalg::PackOp::inferPackedType( + RankedTensorType newExpandType = linalg::PackOp::inferPackedTensorType( expandTy, innerTileSizes, projectedInnerDimsPos, newOuterDimsPerm); auto newExpandOp = rewriter.create( expandOp.getLoc(), newExpandType, unPackOp.getSource(), @@ -970,6 +988,9 @@ class PushDownUnPackOpThroughReshapeOp final LogicalResult matchAndRewrite(linalg::UnPackOp unPackOp, PatternRewriter &rewriter) const override { + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + Value result = unPackOp.getResult(); // Currently only support unpack op with the single user. if (!result.hasOneUse()) { @@ -1146,11 +1167,16 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern { LogicalResult matchAndRewrite(tensor::PadOp padOp, PatternRewriter &rewriter) const override { + linalg::UnPackOp unpackOp = padOp.getSource().getDefiningOp(); + if (!unpackOp) return failure(); + if (!unpackOp.hasPureTensorSemantics()) + return failure(); + if (!controlFn(&padOp.getSourceMutable())) return failure(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp index 0984b6988b93b..59e4b2ff634c2 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/PackAndUnpackPatterns.cpp @@ -111,7 +111,7 @@ struct SimplifyPackToExpandShape : public OpRewritePattern { if (packOp.getPaddingValue()) return rewriter.notifyMatchFailure(packOp, "expects no padding value"); - RankedTensorType sourceType = packOp.getSourceType(); + ShapedType sourceType = packOp.getSourceType(); if (failed(isPackOnInnerMostDim(rewriter, packOp)) && failed(isPackOn1D(rewriter, packOp, sourceType.getShape(), packOp.getStaticTiles())) && @@ -119,7 +119,7 @@ struct 
SimplifyPackToExpandShape : public OpRewritePattern { return failure(); } - RankedTensorType destType = packOp.getDestType(); + ShapedType destType = packOp.getDestType(); auto reassociation = getReassociationIndicesForReshape(sourceType, destType); if (!reassociation) @@ -157,8 +157,8 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern { "expects outer_dims_perm is empty or an identity permutation"); } - RankedTensorType sourceType = unpackOp.getSourceType(); - RankedTensorType destType = unpackOp.getDestType(); + ShapedType sourceType = unpackOp.getSourceType(); + ShapedType destType = unpackOp.getDestType(); if (!sourceType.hasStaticShape() || !destType.hasStaticShape()) return rewriter.notifyMatchFailure(unpackOp, "expects static shapes"); @@ -171,25 +171,27 @@ struct SimplifyUnPackToCollapseShape : public OpRewritePattern { return success(); } - LogicalResult matchAndRewrite(UnPackOp unpackOp, + LogicalResult matchAndRewrite(UnPackOp unPackOp, PatternRewriter &rewriter) const override { - RankedTensorType destType = unpackOp.getDestType(); - if (failed(isUnpackOnInnerMostDim(rewriter, unpackOp)) && - failed(isPackOn1D(rewriter, unpackOp, destType.getShape(), - unpackOp.getStaticTiles())) && - !unpackOp.isLikeUnPad()) { + if (!unPackOp.hasPureTensorSemantics()) + return failure(); + ShapedType destType = unPackOp.getDestType(); + if (failed(isUnpackOnInnerMostDim(rewriter, unPackOp)) && + failed(isPackOn1D(rewriter, unPackOp, destType.getShape(), + unPackOp.getStaticTiles())) && + !unPackOp.isLikeUnPad()) { return failure(); } - RankedTensorType sourceType = unpackOp.getSourceType(); + ShapedType sourceType = unPackOp.getSourceType(); auto reassociation = getReassociationIndicesForReshape(sourceType, destType); if (!reassociation) return failure(); Value collapsed = insertCollapse( - rewriter, unpackOp.getLoc(), unpackOp.getSource(), destType, + rewriter, unPackOp.getLoc(), unPackOp.getSource(), destType, getReassociationIndicesAttribute(rewriter, 
*reassociation)); - rewriter.replaceOp(unpackOp, collapsed); + rewriter.replaceOp(unPackOp, collapsed); return success(); } }; @@ -426,6 +428,8 @@ struct FoldConsumerUnPackWithProducerLinalgTransposeOp LogicalResult matchAndRewrite(UnPackOp unPackOp, PatternRewriter &rewriter) const override { + if (!unPackOp.hasPureTensorSemantics()) + return failure(); auto linalgOp = unPackOp.getSource().getDefiningOp(); if (!linalgOp) return failure(); @@ -507,6 +511,8 @@ struct FoldEmptyTensorWithUnPackOp : public OpRewritePattern { LogicalResult matchAndRewrite(UnPackOp unPackOp, PatternRewriter &rewriter) const override { + if (!unPackOp.hasPureTensorSemantics()) + return failure(); // Check for tensor.empty source. auto emptyOp = unPackOp.getSource().getDefiningOp(); if (!emptyOp) diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp index dcd50cc44f81b..105fdc63c4be8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp @@ -219,6 +219,9 @@ struct PackedOperandsDimList { FailureOr linalg::lowerPack(RewriterBase &rewriter, linalg::PackOp packOp, bool lowerPadLikeWithInsertSlice) { + if (!packOp.hasPureTensorSemantics()) + return failure(); + // 1. Filter out NYI cases. 
auto packedTensorType = cast(packOp->getResultTypes().front()); @@ -265,6 +268,7 @@ FailureOr linalg::lowerPack(RewriterBase &rewriter, highs[pos] = affine::makeComposedFoldedAffineApply( rewriter, loc, map, {outerSize, origSize, innerSize}); } + RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType( RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape), packingMetadata.reassociations); @@ -359,7 +363,7 @@ linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp, OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(unPackOp); - RankedTensorType packedTensorType = unPackOp.getSourceType(); + auto packedTensorType = cast(unPackOp.getSourceType()); int64_t packedRank = packedTensorType.getRank(); OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1); @@ -1026,6 +1030,7 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite( static Value getPackOpSourceOrPaddedSource(OpBuilder &builder, linalg::PackOp packOp) { Value input = packOp.getSource(); + if (!packOp.getPaddingValue()) { return input; } @@ -1142,6 +1147,9 @@ getPackUnpackRankReducedPerm(ArrayRef shape, LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( linalg::PackOp packOp, PatternRewriter &rewriter) const { + if (!packOp.hasPureTensorSemantics()) + return failure(); + // TODO: support the case that outer dimensions are not all 1s. A // tensor.expand_shape will be generated in this case. 
if (llvm::any_of(packOp.getAllOuterDims(), @@ -1243,6 +1251,9 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite( LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite( linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const { + if (!unpackOp.hasPureTensorSemantics()) + return failure(); + int64_t srcRank = unpackOp.getSourceRank(); int64_t destRank = unpackOp.getDestRank(); ArrayRef srcShape = unpackOp.getSourceType().getShape(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index 8c8b1b85ef5a3..28a3707f6d35c 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -1664,12 +1664,11 @@ static LogicalResult vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp, ArrayRef inputVectorSizes, SmallVectorImpl &newResults) { - // TODO: Introduce a parent class that will handle the insertion point update. OpBuilder::InsertionGuard g(rewriter); rewriter.setInsertionPoint(unpackOp); - RankedTensorType unpackTensorType = unpackOp.getSourceType(); + auto unpackTensorType = cast(unpackOp.getSourceType()); ArrayRef innerDimPos = unpackOp.getInnerDimsPos(); ArrayRef innerTiles = unpackOp.getStaticInnerTiles(); @@ -1889,6 +1888,9 @@ vectorizeDynamicLinalgOpPrecondition(linalg::LinalgOp op, static LogicalResult vectorizeUnPackOpPrecondition(linalg::UnPackOp unpackOp, ArrayRef inputVectorSizes) { + // TODO: Support Memref UnPackOp. Temporarily return failure. + if (!unpackOp.hasPureTensorSemantics()) + return failure(); if (llvm::any_of(unpackOp.getInnerTiles(), [](OpFoldResult res) { return !getConstantIntValue(res).has_value(); @@ -2134,6 +2136,10 @@ static LogicalResult vectorizeLinalgOpPrecondition( static LogicalResult vectorizePackOpPrecondition(linalg::PackOp packOp, ArrayRef inputVectorSizes) { + // TODO: Support Memref PackOp. Temporarily return failure. 
+ if (!packOp.hasPureTensorSemantics()) + return failure(); + auto padValue = packOp.getPaddingValue(); Attribute cstAttr; if (padValue && !matchPattern(padValue, m_Constant(&cstAttr))) { diff --git a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp index 0336423c57b1d..86a1fb12f2b26 100644 --- a/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp +++ b/mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp @@ -315,11 +315,11 @@ SmallVector SliceFromCollapseHelper::getExtractSliceParams( // have proven that these are not sliced. In this case we just take // the full extent of each dimension in the reassociation list. if (linearizedDimensions[it.index()]) { - llvm::append_range( - offsetsSizesAndStrides, - llvm::map_range(it.value(), [&](int64_t idx) -> Range { - return {zeroAttr, collapseShapeInputShape[idx], oneAttr}; - })); + llvm::append_range(offsetsSizesAndStrides, + llvm::map_range(it.value(), [&](int64_t idx) -> Range { + return {zeroAttr, collapseShapeInputShape[idx], + oneAttr}; + })); continue; } diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir index 86cb8f58abe02..3598073af11a2 100644 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ b/mlir/test/Dialect/Linalg/canonicalize.mlir @@ -1722,6 +1722,31 @@ func.func @infer_and_fold_pack_unpack_same_tiles(%t: tensor<10x20x4x4xf32>) -> t // ----- +func.func @infer_and_fold_pack_unpack_same_tiles_memref(%t: memref<10x20x4x4xf32>) -> memref<10x20x4x4xf32> { + %c40 = arith.constant 40 : index + %c80 = arith.constant 80 : index + %buf_unpacked = memref.alloc() : memref<40x80xf32> + %unpacked = linalg.unpack %t inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %buf_unpacked : memref<10x20x4x4xf32> -> memref<40x80xf32> + %buf_packed = memref.alloc() : memref<10x20x4x4xf32> + %packed = linalg.pack %unpacked inner_dims_pos = [0, 1] inner_tiles = [4, 4] into %buf_packed : memref<40x80xf32> -> memref<10x20x4x4xf32> + return %packed : 
memref<10x20x4x4xf32> +} +// CHECK-LABEL: func.func @infer_and_fold_pack_unpack_same_tiles_memref +// CHECK-SAME: %[[SRC:[0-9a-zA-Z]+]] +// CHECK: return %[[SRC]] + +// ----- + +func.func @fold_pack_unpack_memref(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) -> memref<2x3xf32> { + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + %c3 = arith.constant 3 : index + %pack_dest = memref.alloc() : memref<2x3x1x1xf32> + %pack = linalg.pack %arg0 inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %pack_dest : memref<2x3xf32> -> memref<2x3x1x1xf32> + %unpack = linalg.unpack %pack inner_dims_pos = [0, 1] inner_tiles = [1, 1] into %arg1 : memref<2x3x1x1xf32> -> memref<2x3xf32> + return %arg1 : memref<2x3xf32> +} + // CHECK-LABEL: func.func @pack_dont_drop_attributes( // CHECK: linalg.pack {{.*}} {test_attr} func.func @pack_dont_drop_attributes(%arg0: tensor, %arg1: tensor<128x?x100x16x1xf16>) -> tensor<128x?x100x16x1xf16> { @@ -1730,6 +1755,7 @@ func.func @pack_dont_drop_attributes(%arg0: tensor, %arg1: tensor<128 %pack = linalg.pack %arg0 padding_value(%cst : f16) outer_dims_perm = [0, 1, 2] inner_dims_pos = [1, 2] inner_tiles = [16, 1] into %arg1 {test_attr} : tensor -> tensor<128x?x100x16x1xf16> return %pack : tensor<128x?x100x16x1xf16> } + // ----- //===----------------------------------------------------------------------===// @@ -1847,3 +1873,49 @@ func.func @no_fold_extract_slice_into_unpack_non_zero_offset( // CHECK-SAME: into %[[DEST]] // CHECK: %[[SLICE:.+]] = tensor.extract_slice %[[UNPACK]] // CHECK: return %[[SLICE]] + +// ----- + +// CHECK-LABEL: func.func @fold_cast_unpack_dynamic_tile_size( +// CHECK-SAME: %[[SRC:.*]]: tensor<1x1x8x1xi32>, +// CHECK-SAME: %[[DEST:.*]]: tensor<7x?xi32>) -> tensor<7x?xi32> { +// CHECK: %[[RES:.*]] = linalg.unpack %[[SRC]] inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %[[DEST]] {test_attr} : tensor<1x1x8x1xi32> -> tensor<7x?xi32> +// CHECK: return %[[RES]] : tensor<7x?xi32> +func.func 
@fold_cast_unpack_dynamic_tile_size( + %src: tensor<1x1x8x1xi32>, + %res: tensor<7x?xi32>) -> tensor<7x?xi32> { + + %cast = tensor.cast %src : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> + %c8 = arith.constant 8 : index + %unpack = linalg.unpack %cast + inner_dims_pos = [0, 1] + inner_tiles = [%c8, 1] + into %res {test_attr} : tensor<1x1x?x1xi32> -> tensor<7x?xi32> + return %unpack : tensor<7x?xi32> +} + +// ----- + +// CHECK-LABEL: func.func @fold_pack_unpack_tensor +// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32> +// CHECK: return %[[ARG0]] : tensor<2x3xf32> +func.func @fold_pack_unpack_tensor(%x: tensor<2x3xf32>) -> tensor<2x3xf32> { + %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = [] + into %x : tensor<2x3xf32> -> tensor<2x3xf32> + %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = [] + into %x : tensor<2x3xf32> -> tensor<2x3xf32> + return %packed : tensor<2x3xf32> +} + +// ----- + +// CHECK-LABEL: func.func @fold_pack_unpack_memref +// CHECK-SAME: (%[[ARG0:.*]]: memref<2x3xf32>) -> memref<2x3xf32> +// CHECK: return %[[ARG0]] : memref<2x3xf32> +func.func @fold_pack_unpack_memref(%x: memref<2x3xf32>) -> memref<2x3xf32> { + %unpacked = linalg.unpack %x outer_dims_perm = [] inner_dims_pos = [] inner_tiles = [] + into %x : memref<2x3xf32> -> memref<2x3xf32> + %packed = linalg.pack %unpacked outer_dims_perm = [] inner_dims_pos = [] inner_tiles = [] + into %x : memref<2x3xf32> -> memref<2x3xf32> + return %packed : memref<2x3xf32> +} diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir index 90ceadebbc1fa..c215dfc3666e5 100644 --- a/mlir/test/Dialect/Linalg/invalid.mlir +++ b/mlir/test/Dialect/Linalg/invalid.mlir @@ -1666,3 +1666,39 @@ func.func @unpack_static_inner_tile_size_and_dynamic_output_shape( %0 = linalg.unpack %input inner_dims_pos = [0, 1] inner_tiles = [8, 4] into %output : tensor -> tensor return %0 : tensor } + +// ----- + 
+func.func @pack_source_dest_type_mismatch_1(%source: tensor<128x256xf32>, %dest: memref<8x16x8x32xf32>) { + // expected-error@+1 {{mixing tensor and buffer semantics is not allowed}} + linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] + into %dest : tensor<128x256xf32> -> memref<8x16x8x32xf32> + return +} + +// ----- + +func.func @pack_source_dest_type_mismatch_2(%source: memref<128x256xf32>, %dest: tensor<8x16x8x32xf32>) { + // expected-error@+1 {{mixing tensor and buffer semantics is not allowed}} + %0 = linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] + into %dest : memref<128x256xf32> -> tensor<8x16x8x32xf32> + return +} + +// ----- + +func.func @unpack_source_dest_type_mismatch_3(%source: tensor<16x8x8x32xf32>, %dest: memref<128x256xf32>) { + // expected-error@+1 {{mixing tensor and buffer semantics is not allowed}} + linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] + into %dest : tensor<16x8x8x32xf32> -> memref<128x256xf32> + return +} + +// ----- + +func.func @unpack_source_dest_type_mismatch_4(%source: memref<16x8x8x32xf32>, %dest: tensor<128x256xf32>) { + // expected-error@+1 {{mixing tensor and buffer semantics is not allowed}} + %0 = linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] + into %dest : memref<16x8x8x32xf32> -> tensor<128x256xf32> + return +} diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir index dc556761b09e5..9c5141f56d575 100644 --- a/mlir/test/Dialect/Linalg/roundtrip.mlir +++ b/mlir/test/Dialect/Linalg/roundtrip.mlir @@ -706,3 +706,29 @@ func.func @conv2d_channel_first_q_promote(%img: tensor<100x3x224x224xi8>, %filt: // CHECK-LABEL: func @conv2d_channel_first_q_promote( // CHECK: %[[arg0:[a-zA-z0-9]*]]: tensor<100x3x224x224xi8>, %[[arg1:[a-zA-z0-9]*]]: tensor<64x3x5x5xi8>, %[[arg2:[a-zA-z0-9]*]]: i8, %[[arg3:[a-zA-z0-9]*]]: i8) // CHECK: linalg.conv_2d_nchw_fchw_q {dilations 
= dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%[[arg0]], %[[arg1]], %[[arg2]], %[[arg3]] : tensor<100x3x224x224xi8>, tensor<64x3x5x5xi8>, i8, i8) outs(%{{.*}} : tensor<100x64x220x220xi32>) -> tensor<100x64x220x220xi32> + +// ----- + +func.func @pack_memref(%source: memref<128x256xf32>, %dest: memref<8x16x8x32xf32>) { + linalg.pack %source outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] + into %dest : memref<128x256xf32> -> memref<8x16x8x32xf32> + return +} + +// CHECK-LABEL: func @pack_memref( +// CHECK: %[[source:[a-zA-z0-9]*]]: memref<128x256xf32>, %[[dest:[a-zA-z0-9]*]]: memref<8x16x8x32xf32>) { +// CHECK: %pack = linalg.pack %arg0 outer_dims_perm = [1, 0] inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<128x256xf32> -> memref<8x16x8x32xf32> +// CHECK: return +// CHECK: } +// ----- + +func.func @unpack_memref(%source: memref<16x8x8x32xf32>, %dest: memref<128x256xf32>) { + linalg.unpack %source inner_dims_pos = [0, 1] inner_tiles = [8, 32] + into %dest : memref<16x8x8x32xf32> -> memref<128x256xf32> + return +} + +// CHECK-LABEL: func @unpack_memref( +// CHECK: %[[source:[a-zA-z0-9]*]]: memref<16x8x8x32xf32>, %[[dest:[a-zA-z0-9]*]]: memref<128x256xf32>) { +// CHECK: %unpack = linalg.unpack %arg0 inner_dims_pos = [0, 1] inner_tiles = [8, 32] into %arg1 : memref<16x8x8x32xf32> -> memref<128x256xf32> +// CHECK: return