[MLIR][Linalg] pack, unpack to take memref inputs #129036
@@ -46,6 +46,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringSet.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/Support/Error.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "llvm/Support/LogicalResult.h"
 #include "llvm/Support/MathExtras.h"
@@ -804,7 +805,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
           rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
     }

-    RankedTensorType srcPadType = srcPadOp.getSourceType();
+    ShapedType srcPadType = srcPadOp.getSourceType();
     SmallVector<OpFoldResult, 4> newSizes;
     for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
       if (srcPadType.isDynamicDim(i)) {
@@ -4427,15 +4428,28 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
         tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
   };

+  // Verify that the source and destination are ranked types.
+  if (!packOrUnPack.getSourceType().hasRank() ||
+      !packOrUnPack.getDestType().hasRank()) {
+    return op->emitError("expected both source and destination to have rank");
+  }
+
   // Verify tiles. Do not allow zero tiles.
   SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
   if (hasZeros(mixedTiles))
     return op->emitError("invalid zero tile factor");

+  // Verify that the Operation does not have mixed tensor/buffer semantics.
+  if (!packOrUnPack.hasPureBufferSemantics() &&
+      !packOrUnPack.hasPureTensorSemantics()) {
+    return op->emitError("mixing tensor and buffer semantics is not allowed");
+  }
+
   // Verify inner_dims_pos and outer_dims_perm.
-  RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
-                                      ? packOrUnPack.getSourceType()
-                                      : packOrUnPack.getDestType();
+  ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
+                                ? packOrUnPack.getSourceType()
+                                : packOrUnPack.getDestType();
   size_t unpackedRank = unpackedType.getRank();
   ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
   ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
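The new mixed-semantics check builds on the DestinationStyleOpInterface predicates. A minimal sketch of the intended behavior (not the actual interface implementation):

```cpp
// Sketch only: an op has pure tensor semantics when every operand is a
// tensor, and pure buffer semantics when every operand is a memref. A pack
// whose source is a tensor but whose dest is a memref matches neither
// predicate, so the verifier above now rejects it.
static bool allOperandsAre(Operation *op, bool wantTensor) {
  return llvm::all_of(op->getOperands(), [&](Value v) {
    return wantTensor ? isa<TensorType>(v.getType())
                      : isa<BaseMemRefType>(v.getType());
  });
}
```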
@@ -4472,12 +4486,23 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // Verify result shape is greater than the minimum expected
   // by the pack operation, and that the output shape
   // represents full tiles.
-  RankedTensorType expectedPackedType = PackOp::inferPackedType(
-      unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
-  if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
+  SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
+      unpackedType.getShape(), packOrUnPack.getStaticTiles(),
+      packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
+
+  if (!areAllInBound(expectedPackedShape, packedType.getShape())) {
+    auto elementType = unpackedType.getElementType();
+    Type expectedType, actualType;
+    if (packOrUnPack.hasPureTensorSemantics()) {
+      expectedType = RankedTensorType::get(expectedPackedShape, elementType);
+      actualType = RankedTensorType::get(packedType.getShape(), elementType);
+    } else {
+      expectedType = MemRefType::get(expectedPackedShape, elementType);
+      actualType = MemRefType::get(packedType.getShape(), elementType);
+    }
     return op->emitError("the shape of output is not large enough to hold the "
                          "packed data. Expected at least ")
-           << expectedPackedType << ", got " << packedType;
+           << expectedType << ", got " << actualType;
   }
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
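Since the error now constructs the printed types to match the op's semantics, a failing buffer-semantics op reports memref types. An illustrative diagnostic (hypothetical shapes: a 31x13xf32 source, 8x16 inner tiles, and an undersized destination):

```
error: the shape of output is not large enough to hold the packed data. Expected at least 'memref<4x1x8x16xf32>', got 'memref<3x1x8x16xf32>'
```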
@@ -4681,13 +4706,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
   return result;
 }

-/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
-/// the packed type. Having a shared helper helps implement these two methods in
-/// a way that ensures that they agree on which dimensions are dynamic.
-static SmallVector<int64_t> getPackOpResultTypeShape(
-    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
-    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
+SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
+                                              ArrayRef<int64_t> innerTileSizes,
+                                              ArrayRef<int64_t> innerDimsPos,
+                                              ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
   for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
     if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
       continue;
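The shared helper is now a public static method on PackOp, reused by both the tensor and memref type inference below. A quick worked example of the shape arithmetic (values chosen for illustration):

```cpp
// Pack a 31x13 source with inner tiles [8, 16] on inner_dims_pos = [0, 1]:
// each tiled outer dim becomes ceilDiv(dim, tile) and the tile sizes are
// appended, so {31, 13} -> {ceilDiv(31, 8), ceilDiv(13, 16), 8, 16}.
SmallVector<int64_t> packedShape = PackOp::inferPackedShape(
    /*inputShape=*/{31, 13}, /*innerTileSizes=*/{8, 16},
    /*innerDimsPos=*/{0, 1}, /*outerDimsPerm=*/{});
// packedShape == {4, 1, 8, 16}
```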
@@ -4727,9 +4750,9 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
   resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

   SmallVector<int64_t> resultTypeShape =
-      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
-                               asShapeWithAnyValueAsDynamic(innerTileSizes),
-                               innerDimsPos, outerDimsPerm);
+      inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
+                       asShapeWithAnyValueAsDynamic(innerTileSizes),
+                       innerDimsPos, outerDimsPerm);

   // Fix-up `resultDims` to ensure that they are Value's if and only if the
   // result type shape says it's a dynamic dim. This is needed as callers may
@@ -4747,13 +4770,21 @@ SmallVector<OpFoldResult> PackOp::getResultShape(

 /// Get the expected packed type based on source type, tile factors, position of
 /// the inner tiles and permutation of the outer tiled loop.
-RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
-                                         ArrayRef<int64_t> innerTileSizes,
-                                         ArrayRef<int64_t> innerDimsPos,
-                                         ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+RankedTensorType PackOp::inferPackedTensorType(
+    RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
+    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = inferPackedShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
   return RankedTensorType::get(resultShape, sourceType.getElementType());
 }
+
+MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
+                                         ArrayRef<int64_t> innerTileSizes,
+                                         ArrayRef<int64_t> innerDimsPos,
+                                         ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = inferPackedShape(
+      sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
+  return MemRefType::get(resultShape, sourceType.getElementType());
+}

Review comments on `PackOp::inferPackedTensorType`:
- "The above comment needs to be removed because they are documented in the declaration."
- "bump"
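Callers that handle both semantics can dispatch on the kind of the source type. A hypothetical convenience wrapper (not part of this patch):

```cpp
// Sketch: choose the tensor or memref inference variant from a ShapedType.
static ShapedType inferPackedShapedType(ShapedType sourceType,
                                        ArrayRef<int64_t> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm) {
  if (auto tensorType = dyn_cast<RankedTensorType>(sourceType))
    return PackOp::inferPackedTensorType(tensorType, innerTileSizes,
                                         innerDimsPos, outerDimsPerm);
  return PackOp::inferPackedMemRefType(cast<MemRefType>(sourceType),
                                       innerTileSizes, innerDimsPos,
                                       outerDimsPerm);
}
```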
@@ -4802,6 +4833,45 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
                                      getPaddingValue(), metadata.outerDimsPerm);
 }

+template <typename OpTy>
+static void getPackUnPackEffectsImpl(
+    OpTy op, SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+                 &effects) {
+  // No memory effects for pure tensor semantics
+  if (op.hasPureTensorSemantics())
+    return;
+
+  for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
+    if (!llvm::isa<MemRefType>(opOperand.get().getType()))
+      continue;
+
+    if (&opOperand == &op.getSourceMutable()) {
+      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
+    } else if (&opOperand == &op.getDestMutable()) {
+      effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
+                           /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
+      effects.emplace_back(MemoryEffects::Write::get(), &opOperand,
+                           /*stage=*/0, /*effectOnFullRegion=*/true,
+                           SideEffects::DefaultResource::get());
+    }
+  }
+}
+
+void PackOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getPackUnPackEffectsImpl(*this, effects);
+}
+
+void UnPackOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getPackUnPackEffectsImpl(*this, effects);
+}
+
 /// Returns true if the tiles and the tiled dims are constant.
 template <typename OpTy>
 bool areTilesAndTiledDimsAllConstant(OpTy op) {
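With buffer semantics the source is only read, while the destination is declared both read and written, presumably because a pack is not guaranteed to overwrite every element of the destination. A hedged sketch of how a pass could inspect the declared effects:

```cpp
// Sketch: collect the effects a memref-semantics pack declares and check
// whether any operand is written (packOp is assumed to be a linalg::PackOp).
SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
packOp.getEffects(effects);
bool writesMemory = llvm::any_of(effects, [](auto &instance) {
  return isa<MemoryEffects::Write>(instance.getEffect());
});
```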
@@ -4821,6 +4891,9 @@ bool areTilesAndTiledDimsAllConstant(OpTy op) {
 }

 Speculation::Speculatability PackOp::getSpeculatability() {
+  if (!hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
+
   if (getPaddingValue())
     return Speculation::Speculatable;
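Buffer-semantics ops are conservatively marked non-speculatable, since hoisting a memory read or write past a guard is unsafe in general. A transform would typically consult this through MLIR's `isSpeculatable` helper; a sketch (the hoisting entry point is hypothetical):

```cpp
// Sketch: only hoist the op out of a conditional when it is speculatable;
// memref-semantics pack/unpack ops now never are.
if (mlir::isSpeculatable(packOp))
  hoistAboveGuard(packOp); // hypothetical transform entry point
```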
@@ -4933,7 +5006,8 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {

   // Insert tensor.cast ops if static shape inference is available..
   SmallVector<int64_t> srcShape, destShape;
-  if (inferStaticShape(packOp, srcShape, destShape)) {
+  if (inferStaticShape(packOp, srcShape, destShape) &&
+      packOp.hasPureTensorSemantics()) {
     Location loc = packOp.getLoc();
     Value source = packOp.getSource();
     if (srcShape != packOp.getSourceType().getShape()) {
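Note that this canonicalization materializes tensor.cast ops, which only exist for tensor types; that is why it is now gated on pure tensor semantics. A memref analogue would presumably use memref.cast instead; the patch simply skips buffer semantics here.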
@@ -4942,7 +5016,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
     }
     Value dest = packOp.getDest();
-    RankedTensorType originalResultType = packOp.getDestType();
+    ShapedType originalResultType = packOp.getDestType();
     bool needUpdateDestType = (destShape != originalResultType.getShape());
     if (needUpdateDestType) {
       auto newDestType = packOp.getDestType().clone(destShape);
@@ -4961,15 +5035,15 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
           rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
       rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
     }

     return success();
   }

   return failure();
 }

 template <typename PackOrUnpackOp>
-static bool isLikePadUnPad(PackOrUnpackOp packOp,
-                           RankedTensorType packedTensorType) {
+static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
   static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
                     std::is_same<PackOrUnpackOp, UnPackOp>::value,
                 "Function meant for pack/unpack");
@@ -5002,17 +5076,20 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp,

 bool PackOp::isLikePad() {
   auto packedTensorType =
-      llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
+      llvm::dyn_cast<ShapedType>((*this)->getResultTypes().front());
   return isLikePadUnPad(*this, packedTensorType);
 }

 OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
+  if (!hasPureTensorSemantics())
+    return {};
+
   std::optional<Attribute> paddingValue;
   if (auto pad = adaptor.getPaddingValue())
     paddingValue = pad;
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          getDestType(), paddingValue))
+          cast<TensorType>(getDestType()), paddingValue))
     return reshapedSource;
   return {};
 }
@@ -5039,6 +5116,10 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();

+    // TODO: Support Memref PackOp. Temporarily return failure.
+    if (!op.hasPureTensorSemantics())
+      return failure();
+
     SmallVector<Type> newResultTypes(op->getResultTypes());
     SmallVector<Value> newOperands =
         tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
@@ -5119,6 +5200,9 @@ LogicalResult UnPackOp::verify() {
 }

 Speculation::Speculatability UnPackOp::getSpeculatability() {
+  if (!hasPureTensorSemantics())
+    return Speculation::NotSpeculatable;
+
   // See PackOp::getSpeculatability.
   if (!areTilesAndTiledDimsAllConstant(*this))
     return Speculation::NotSpeculatable;
@@ -5296,14 +5380,16 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
 }

 bool UnPackOp::isLikeUnPad() {
-  RankedTensorType packedTensorType = getSourceType();
+  ShapedType packedTensorType = getSourceType();
   return isLikePadUnPad(*this, packedTensorType);
 }

 OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
+  if (!hasPureTensorSemantics())
+    return {};
   if (OpFoldResult reshapedSource = reshapeConstantSource(
           llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
-          getResult().getType()))
+          cast<TensorType>(getResult().getType())))
     return reshapedSource;
   return {};
 }
@@ -5330,6 +5416,10 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();

+    // TODO: Support Memref UnPackOp. Temporarily return failure.
+    if (!op.hasPureTensorSemantics())
+      return failure();
+
     SmallVector<Type> newResultTypes(op->getResultTypes());
     SmallVector<Value> newOperands =
         tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);

Review comments on lines +5420 to +5421:
- "Let's add a TODO for consistency. It is a reasonable folder to me and we should support it (in follow-ups)."
- "This is unpack, not pack."