[MLIR][Linalg] pack, unpack to take memref inputs #129036

Open. Wants to merge 35 commits into base: main.

Commits (35)
4d523ad
draft
ita9naiwa Feb 27, 2025
4f2dbf4
draft
ita9naiwa Feb 27, 2025
226230c
draft
ita9naiwa Feb 27, 2025
0c184df
init
ita9naiwa Feb 28, 2025
19201c6
lint
ita9naiwa Feb 28, 2025
b99b920
lint
ita9naiwa Feb 28, 2025
be6a119
add
ita9naiwa Feb 28, 2025
eee8805
remove tensor casting
ita9naiwa Mar 1, 2025
c5b3c39
add test
ita9naiwa Mar 1, 2025
714e4c4
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Mar 16, 2025
a5d01df
fix upon review
ita9naiwa Mar 16, 2025
2480616
lint
ita9naiwa Mar 23, 2025
0421e72
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Mar 23, 2025
7b92a4e
format fix
ita9naiwa Mar 24, 2025
6dc08ae
revert changes
ita9naiwa Mar 25, 2025
cf7be57
revert changes
ita9naiwa Mar 25, 2025
4e2f00d
nit
ita9naiwa Mar 25, 2025
ee7a42a
fix upon review: Add getEffects for PackOp and UnPackOp
ita9naiwa Mar 27, 2025
5b95ee8
make clang-format happy
ita9naiwa Mar 27, 2025
8b5ac5a
make clang-format happy
ita9naiwa Mar 27, 2025
c955d21
wrap getEffects function
ita9naiwa Mar 27, 2025
4bedf40
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Mar 27, 2025
276069d
fix upon review
ita9naiwa Mar 29, 2025
790e974
bail out transforms using PackOp, UnPackOp
ita9naiwa Mar 30, 2025
820e40b
fix build error
ita9naiwa Mar 30, 2025
43a64b9
fix build error
ita9naiwa Mar 30, 2025
a3bba60
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Apr 2, 2025
486c62b
add invalid pack/unpack cases
ita9naiwa Apr 2, 2025
ca889b5
fix roundtrip test
ita9naiwa Apr 2, 2025
ce910b9
fix upon review
ita9naiwa Apr 2, 2025
6a501bd
fix upon review
ita9naiwa Apr 2, 2025
535e796
.
ita9naiwa Apr 2, 2025
4cbbb80
Merge branch 'main' into ita9naiwa/pack-memref
ita9naiwa Apr 6, 2025
17ad838
fix upon review
ita9naiwa Apr 13, 2025
2aca3fd
Update mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
ita9naiwa Apr 20, 2025
36 changes: 24 additions & 12 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td
@@ -34,7 +34,7 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
Op<Linalg_Dialect, mnemonic, !listconcat(traits, [
DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
DestinationStyleOpInterface, LinalgRelayoutOpInterface,
ConditionallySpeculatable, NoMemoryEffect,
ConditionallySpeculatable, DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
TypesMatchWith<"result type matches type of dest",
"dest", "result",
@@ -43,10 +43,10 @@ class Linalg_RelayoutOp<string mnemonic, list<Trait> traits = []> :
code commonExtraClassDeclaration = [{
size_t getSourceRank() { return getSourceType().getRank(); };
size_t getDestRank() { return getDestType().getRank(); };
RankedTensorType getSourceType() {
return ::llvm::cast<RankedTensorType>(getSource().getType()); };
RankedTensorType getDestType() {
return ::llvm::cast<RankedTensorType>(getDest().getType()); };
ShapedType getSourceType() {
return ::llvm::cast<ShapedType>(getSource().getType()); };
ShapedType getDestType() {
return ::llvm::cast<ShapedType>(getDest().getType()); };

MutableOperandRange getDpsInitsMutable() { return getDestMutable(); }

@@ -152,14 +152,14 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Note: Only tiled dimensions can be padded.
```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
let arguments = (ins AnyShaped:$source,
AnyShaped:$dest,
Optional<AnyType>:$padding_value,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`padding_value` `(` $padding_value^ `:` type($padding_value) `)`)?
@@ -190,7 +190,19 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
// Method to get the `RankedTensorType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static RankedTensorType inferPackedType(RankedTensorType sourceType,
static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

// Method to get the `MemRefType` of the result based on the inner
// tiles, position of the inner tiles (innerDimsPos) and interchange vector
// of outer loops (outerDimsPerm).
static MemRefType inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

// Returns the shape of the packed type. It is a shared helper that ensures
// the type inference methods agree on which dimensions are dynamic.
static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm = {});

Expand Down Expand Up @@ -279,13 +291,13 @@ def Linalg_UnPackOp : Linalg_RelayoutOp<"unpack"> {
: tensor<8x16x8x32xf32> -> tensor<128x256xf32>
```
}];
let arguments = (ins AnyRankedTensor:$source,
AnyRankedTensor:$dest,
let arguments = (ins AnyShaped:$source,
AnyShaped:$dest,
DefaultValuedOptionalAttr<DenseI64ArrayAttr, "{}">:$outer_dims_perm,
DenseI64ArrayAttr:$inner_dims_pos,
Variadic<Index>:$inner_tiles,
DenseI64ArrayAttr:$static_inner_tiles);
let results = (outs AnyRankedTensor:$result);
let results = (outs AnyShaped:$result);
let assemblyFormat = [{
$source
(`outer_dims_perm` `=` $outer_dims_perm^)?
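For intuition about the computation the new shared `inferPackedShape` helper centralizes, here is a minimal standalone C++ sketch of the documented pack semantics: each tiled outer dimension is rounded up by its tile size, the outer dimensions are optionally permuted, and the tile sizes are appended as trailing dimensions. The function and constant names below are illustrative only, not taken from MLIR.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Illustrative stand-in for ShapedType::kDynamic.
constexpr int64_t kDynamicSize = INT64_MIN;

// Sketch of packed-shape inference: dims listed in innerDimsPos are
// ceil-divided by their tile size, the outer dims are permuted by
// outerDimsPerm (if any), and the tile sizes become trailing dims.
std::vector<int64_t>
inferPackedShapeSketch(std::vector<int64_t> shape,
                       const std::vector<int64_t> &innerTileSizes,
                       const std::vector<int64_t> &innerDimsPos,
                       const std::vector<int64_t> &outerDimsPerm) {
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    int64_t &dim = shape[innerDimsPos[i]];
    if (dim == kDynamicSize || innerTileSizes[i] == kDynamicSize)
      dim = kDynamicSize;
    else
      dim = (dim + innerTileSizes[i] - 1) / innerTileSizes[i]; // ceil-div
  }
  if (!outerDimsPerm.empty()) {
    std::vector<int64_t> permuted(shape.size());
    for (size_t i = 0; i < outerDimsPerm.size(); ++i)
      permuted[i] = shape[outerDimsPerm[i]];
    shape = permuted;
  }
  shape.insert(shape.end(), innerTileSizes.begin(), innerTileSizes.end());
  return shape;
}

int main() {
  // A 128x256 source packed with 8x32 tiles on dims (0, 1) becomes
  // 16x8x8x32, regardless of whether the operands are tensors or memrefs.
  std::vector<int64_t> packed =
      inferPackedShapeSketch({128, 256}, {8, 32}, {0, 1}, {});
  assert((packed == std::vector<int64_t>{16, 8, 8, 32}));
  return 0;
}
```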
4 changes: 1 addition & 3 deletions mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp
@@ -77,10 +77,8 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
omp::FlushOp, omp::MapBoundsOp,
omp::ThreadprivateOp>::value) {
if (isa<MemRefType>(originalOperand.getType())) {
// TODO: Support memref type in variable operands
if (isa<MemRefType>(originalOperand.getType()))
return rewriter.notifyMatchFailure(op, "memref is not supported yet");
}
}
convertedOperands.push_back(convertedOperand);
}
148 changes: 119 additions & 29 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -46,6 +46,7 @@
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/LogicalResult.h"
#include "llvm/Support/MathExtras.h"
@@ -804,7 +805,7 @@ struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
}

RankedTensorType srcPadType = srcPadOp.getSourceType();
ShapedType srcPadType = srcPadOp.getSourceType();
SmallVector<OpFoldResult, 4> newSizes;
for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
if (srcPadType.isDynamicDim(i)) {
@@ -4427,15 +4428,28 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
tiles, [](OpFoldResult tile) { return isConstantIntValue(tile, 0); });
};

// Verify that the source and destination are ranked types.
if (!packOrUnPack.getSourceType().hasRank() ||
!packOrUnPack.getDestType().hasRank()) {
return op->emitError("expected both source and destination to have rank");
}

// Verify tiles. Do not allow zero tiles.
SmallVector<OpFoldResult> mixedTiles = packOrUnPack.getMixedTiles();
if (hasZeros(mixedTiles))
return op->emitError("invalid zero tile factor");

// Verify that the Operation does not have mixed tensor/buffer semantics.
if (!packOrUnPack.hasPureBufferSemantics() &&
!packOrUnPack.hasPureTensorSemantics()) {
return op->emitError("mixing tensor and buffer semantics is not allowed");
}

// Verify inner_dims_pos and outer_dims_perm.
RankedTensorType unpackedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getSourceType()
: packOrUnPack.getDestType();
ShapedType unpackedType = (std::is_same<OpTy, PackOp>::value)
? packOrUnPack.getSourceType()
: packOrUnPack.getDestType();

size_t unpackedRank = unpackedType.getRank();
ArrayRef<int64_t> innerDimsPos = packOrUnPack.getInnerDimsPos();
ArrayRef<int64_t> outerDimPerm = packOrUnPack.getOuterDimsPerm();
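The new mixed-semantics check leans on the DestinationStyleOpInterface semantics queries used above. A hedged sketch of the rule in isolation; the free-standing helper name is hypothetical, while the two accessors are the ones the diff itself calls:

```cpp
// Hypothetical helper mirroring the verifier rule added above: a relayout op
// must be all-tensor or all-buffer. Packing a memref source into a tensor
// destination (or vice versa) is rejected.
static LogicalResult verifyNoMixedSemantics(DestinationStyleOpInterface op) {
  if (!op.hasPureBufferSemantics() && !op.hasPureTensorSemantics())
    return op.getOperation()->emitError(
        "mixing tensor and buffer semantics is not allowed");
  return success();
}
```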
@@ -4472,12 +4486,23 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
// Verify result shape is greater than the minimum expected
// by the pack operation, and that the output shape
// represents full tiles.
RankedTensorType expectedPackedType = PackOp::inferPackedType(
unpackedType, packOrUnPack.getStaticTiles(), innerDimsPos, outerDimPerm);
if (!areAllInBound(expectedPackedType.getShape(), packedType.getShape())) {
SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
unpackedType.getShape(), packOrUnPack.getStaticTiles(),
packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());

if (!areAllInBound(expectedPackedShape, packedType.getShape())) {
auto elementType = unpackedType.getElementType();
Type expectedType, actualType;
if (packOrUnPack.hasPureTensorSemantics()) {
expectedType = RankedTensorType::get(expectedPackedShape, elementType);
actualType = RankedTensorType::get(packedType.getShape(), elementType);
} else {
expectedType = MemRefType::get(expectedPackedShape, elementType);
actualType = MemRefType::get(packedType.getShape(), elementType);
}
return op->emitError("the shape of output is not large enough to hold the "
"packed data. Expected at least ")
<< expectedPackedType << ", got " << packedType;
<< expectedType << ", got " << actualType;
}
if (!llvm::all_of(
llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
@@ -4681,13 +4706,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
return result;
}

/// Helper for PackOp::{getResultShape,inferPackedType}. Returns the shape of
/// the packed type. Having a shared helper helps implement these two methods in
/// a way that ensures that they agree on which dimensions are dynamic.
static SmallVector<int64_t> getPackOpResultTypeShape(
ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
continue;
@@ -4727,9 +4750,9 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

SmallVector<int64_t> resultTypeShape =
getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
asShapeWithAnyValueAsDynamic(innerTileSizes),
innerDimsPos, outerDimsPerm);
inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
asShapeWithAnyValueAsDynamic(innerTileSizes),
innerDimsPos, outerDimsPerm);

// Fix-up `resultDims` to ensure that they are Value's if and only if the
// result type shape says it's a dynamic dim. This is needed as callers may
@@ -4747,13 +4770,21 @@

/// Get the expected packed type based on source type, tile factors, position of
/// the inner tiles and permutation of the outer tiled loop.
RankedTensorType PackOp::inferPackedType(RankedTensorType sourceType,
RankedTensorType PackOp::inferPackedTensorType(
[Review comment, Contributor]: The above comment needs to be removed because it is documented in the declaration.
[Review comment, Contributor]: bump
RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultShape = inferPackedShape(
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
return RankedTensorType::get(resultShape, sourceType.getElementType());
}

MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
ArrayRef<int64_t> innerTileSizes,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outerDimsPerm) {
SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
SmallVector<int64_t> resultShape = inferPackedShape(
sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
return RankedTensorType::get(resultShape, sourceType.getElementType());
return MemRefType::get(resultShape, sourceType.getElementType());
}
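With the inference split into inferPackedTensorType and inferPackedMemRefType, a caller that only has a ShapedType in hand could dispatch on the kind of type. A minimal sketch; the wrapper name is hypothetical and not part of this PR:

```cpp
// Hypothetical convenience wrapper: pick the tensor or the memref inference
// routine depending on the kind of the ranked source type.
static ShapedType inferPackedShapedType(ShapedType sourceType,
                                        ArrayRef<int64_t> innerTileSizes,
                                        ArrayRef<int64_t> innerDimsPos,
                                        ArrayRef<int64_t> outerDimsPerm = {}) {
  if (auto tensorType = dyn_cast<RankedTensorType>(sourceType))
    return PackOp::inferPackedTensorType(tensorType, innerTileSizes,
                                         innerDimsPos, outerDimsPerm);
  return PackOp::inferPackedMemRefType(cast<MemRefType>(sourceType),
                                       innerTileSizes, innerDimsPos,
                                       outerDimsPerm);
}
```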

Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
@@ -4802,6 +4833,45 @@ PackOp PackOp::createTransposedClone(OpBuilder &b, Location loc,
getPaddingValue(), metadata.outerDimsPerm);
}

template <typename OpTy>
static void getPackUnPackEffectsImpl(
OpTy op, SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
// No memory effects for pure tensor semantics
if (op.hasPureTensorSemantics())
return;

for (OpOperand &opOperand : op.getOperation()->getOpOperands()) {
if (!llvm::isa<MemRefType>(opOperand.get().getType()))
continue;

if (&opOperand == &op.getSourceMutable()) {
effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
} else if (&opOperand == &op.getDestMutable()) {
effects.emplace_back(MemoryEffects::Read::get(), &opOperand, /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
effects.emplace_back(MemoryEffects::Write::get(), &opOperand, /*stage=*/0,
/*effectOnFullRegion=*/true,
SideEffects::DefaultResource::get());
}
}
}

void PackOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getPackUnPackEffectsImpl(*this, effects);
}

void UnPackOp::getEffects(
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
&effects) {
getPackUnPackEffectsImpl(*this, effects);
}
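Because the ops now implement MemoryEffectsOpInterface instead of carrying NoMemoryEffect, analyses can query their effects: with buffer semantics, a pack reports a read of its source and a read plus write of its destination. A rough sketch of what a consumer might do; the inspection function is hypothetical and uses the standard MLIR side-effect API:

```cpp
// Sketch: query the effects a (possibly bufferized) pack declares, so generic
// passes no longer treat it as side-effect free.
static void inspectPackEffects(linalg::PackOp packOp) {
  SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
  cast<MemoryEffectOpInterface>(packOp.getOperation()).getEffects(effects);
  for (const auto &instance : effects)
    if (isa<MemoryEffects::Write>(instance.getEffect()))
      llvm::errs() << "pack writes one of its memref operands\n";
}
```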

/// Returns true if the tiles and the tiled dims are constant.
template <typename OpTy>
bool areTilesAndTiledDimsAllConstant(OpTy op) {
@@ -4821,6 +4891,9 @@ bool areTilesAndTiledDimsAllConstant(OpTy op) {
}

Speculation::Speculatability PackOp::getSpeculatability() {
if (!hasPureTensorSemantics())
return Speculation::NotSpeculatable;

if (getPaddingValue())
return Speculation::Speculatable;

@@ -4933,7 +5006,8 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {

// Insert tensor.cast ops if static shape inference is available.
SmallVector<int64_t> srcShape, destShape;
if (inferStaticShape(packOp, srcShape, destShape)) {
if (inferStaticShape(packOp, srcShape, destShape) &&
packOp.hasPureTensorSemantics()) {
Location loc = packOp.getLoc();
Value source = packOp.getSource();
if (srcShape != packOp.getSourceType().getShape()) {
Expand All @@ -4942,7 +5016,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
RankedTensorType originalResultType = packOp.getDestType();
ShapedType originalResultType = packOp.getDestType();
bool needUpdateDestType = (destShape != originalResultType.getShape());
if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
Expand All @@ -4961,15 +5035,15 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
}

return success();
}

return failure();
}

template <typename PackOrUnpackOp>
static bool isLikePadUnPad(PackOrUnpackOp packOp,
RankedTensorType packedTensorType) {
static bool isLikePadUnPad(PackOrUnpackOp packOp, ShapedType packedTensorType) {
static_assert(std::is_same<PackOrUnpackOp, PackOp>::value ||
std::is_same<PackOrUnpackOp, UnPackOp>::value,
"Function meant for pack/unpack");
@@ -5002,17 +5076,20 @@ static bool isLikePadUnPad(PackOrUnpackOp packOp,

bool PackOp::isLikePad() {
auto packedTensorType =
llvm::cast<RankedTensorType>((*this)->getResultTypes().front());
llvm::dyn_cast<ShapedType>((*this)->getResultTypes().front());
return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult PackOp::fold(FoldAdaptor adaptor) {
if (!hasPureTensorSemantics())
return {};

std::optional<Attribute> paddingValue;
if (auto pad = adaptor.getPaddingValue())
paddingValue = pad;
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
getDestType(), paddingValue))
cast<TensorType>(getDestType()), paddingValue))
return reshapedSource;
return {};
}
@@ -5039,6 +5116,10 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
if (!tensor::hasFoldableTensorCastOperand(op))
return failure();

// TODO: Support Memref PackOp. Temporarily return failure.
if (!op.hasPureTensorSemantics())
return failure();

SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands =
tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);
@@ -5119,6 +5200,9 @@ LogicalResult UnPackOp::verify() {
}

Speculation::Speculatability UnPackOp::getSpeculatability() {
if (!hasPureTensorSemantics())
return Speculation::NotSpeculatable;

// See PackOp::getSpeculatability.
if (!areTilesAndTiledDimsAllConstant(*this))
return Speculation::NotSpeculatable;
@@ -5296,14 +5380,16 @@ LogicalResult UnPackOp::canonicalize(UnPackOp unPackOp,
}

bool UnPackOp::isLikeUnPad() {
RankedTensorType packedTensorType = getSourceType();
ShapedType packedTensorType = getSourceType();
return isLikePadUnPad(*this, packedTensorType);
}

OpFoldResult UnPackOp::fold(FoldAdaptor adaptor) {
if (!hasPureTensorSemantics())
return {};
if (OpFoldResult reshapedSource = reshapeConstantSource(
llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getSource()),
getResult().getType()))
cast<TensorType>(getResult().getType())))
return reshapedSource;
return {};
}
Expand All @@ -5330,6 +5416,10 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
if (!tensor::hasFoldableTensorCastOperand(op))
return failure();

// TODO: Support Memref UnPackOp. Temporarily return failure.
if (!op.hasPureTensorSemantics())
return failure();
[Review comment, Contributor, on lines +5420 to +5421]: Let's add a TODO for consistency. It is a reasonable folder to me and we should support it (in follow-ups).
[Review comment, Contributor]: This is unpack, not pack.

SmallVector<Type> newResultTypes(op->getResultTypes());
SmallVector<Value> newOperands =
tensor::getUpdatedOperandsAfterCastOpFolding(op, newResultTypes);