
Commit 17ad838

fix upon review
Signed-off-by: Hyunsung Lee <[email protected]>
1 parent 4cbbb80 commit 17ad838

11 files changed (+135, -79 lines)


mlir/include/mlir/Dialect/Linalg/IR/LinalgRelayoutOps.td

Lines changed: 9 additions & 0 deletions
@@ -190,20 +190,29 @@ def Linalg_PackOp : Linalg_RelayoutOp<"pack", [
     // Method to get the `RankedTensorType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos) and interchange vector
     // of outer loops (outerDimsPerm).
+    /// This method uses inferPackedShape to ensure consistency with other shape
+    /// inference methods regarding which dimensions are dynamic.
     static RankedTensorType inferPackedTensorType(RankedTensorType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});

     // Method to get the `MemRefType` of the result based on the inner
     // tiles, position of the inner tiles (innerDimsPos) and interchange vector
     // of outer loops (outerDimsPerm).
+    /// This method uses inferPackedShape to ensure consistency with other shape
+    /// inference methods regarding which dimensions are dynamic.
     static MemRefType inferPackedMemRefType(MemRefType sourceType,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});

     // Method to get the Shape of the result based on the input shape, inner
     // tiles, position of the inner tiles (innerDimsPos) and interchange vector
     // of outer loops (outerDimsPerm).
+
+    /// Helper for PackOp::{getResultShape, inferPackedTensorType, inferPackedMemRefType}.
+    /// Returns the shape of the packed type. Having a shared helper helps
+    /// implement these three methods in a way that ensures
+    /// that they agree on which dimensions are dynamic.
     static SmallVector<int64_t> inferPackedShape(ArrayRef<int64_t> inputShape,
         ArrayRef<int64_t> innerTileSizes, ArrayRef<int64_t> innerDimsPos,
         ArrayRef<int64_t> outerDimsPerm = {});
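The shape arithmetic behind these three entry points: every dimension listed in innerDimsPos shrinks to ceil(extent / tile), staying dynamic if either the extent or the tile size is dynamic; the outer dims are then optionally interchanged by outerDimsPerm; finally the inner tile sizes are appended as trailing dimensions. Below is a minimal standalone sketch of that computation, with -1 standing in for ShapedType::kDynamic and std::vector for ArrayRef (a sketch of the documented behavior, not the MLIR implementation itself):

#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamic

std::vector<int64_t> inferPackedShapeSketch(
    std::vector<int64_t> shape, const std::vector<int64_t> &innerTileSizes,
    const std::vector<int64_t> &innerDimsPos,
    const std::vector<int64_t> &outerDimsPerm) {
  // Each tiled dimension shrinks to ceil(extent / tile); a dynamic extent
  // or a dynamic tile size keeps the result dynamic.
  for (size_t i = 0; i < innerDimsPos.size(); ++i) {
    int64_t &dim = shape[innerDimsPos[i]];
    if (dim == kDynamic)
      continue;
    int64_t tile = innerTileSizes[i];
    dim = (tile == kDynamic) ? kDynamic : (dim + tile - 1) / tile;
  }
  // Optionally interchange the outer dims.
  if (!outerDimsPerm.empty()) {
    std::vector<int64_t> permuted;
    for (int64_t srcIdx : outerDimsPerm)
      permuted.push_back(shape[srcIdx]);
    shape = permuted;
  }
  // The inner tiles become trailing dimensions.
  shape.insert(shape.end(), innerTileSizes.begin(), innerTileSizes.end());
  return shape;
}

int main() {
  // Packing a 128x256 source with tiles [32, 16] on dims [0, 1]:
  // the outer dims become 128/32 = 4 and 256/16 = 16, the tiles trail.
  assert((inferPackedShapeSketch({128, 256}, {32, 16}, {0, 1}, {}) ==
          std::vector<int64_t>{4, 16, 32, 16}));
  return 0;
}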

mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

Lines changed: 1 addition & 1 deletion
@@ -1801,7 +1801,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
     let summary = "store operation";
     let description = [{
       The `store` op stores an element into a memref at the specified indices.
-
+
       The number of indices must match the rank of the memref. The indices must
       be in-bounds: `0 <= idx < dim_size`

mlir/include/mlir/Dialect/Utils/ReshapeOpsUtils.h

Lines changed: 1 addition & 1 deletion
@@ -451,7 +451,7 @@ getLinearizedDimensions(ArrayRef<ReassociationIndices> reassociationIndices);
 ///   %4 = tensor.extract_slice %0 [%3#0, %3#1, %3#2, 0] [1, 1, 1, 10] [1, 1, 1, 1] :
 ///     tensor<3x7x11x10xf32> to tensor<1x1x1x10xf32>
 ///
-///   %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
+///   %5 = tensor.collapse_shape %4 [[0, 1, 2], [3]] :
 ///     tensor<1x1x1x10xf32> into tensor<1x10xf32>
 ///   %6 = tensor.insert_slice %5 into %arg0 [%iv, 0] [1, 10] [1, 1] :
 ///     tensor<1x10xf32> into tensor<10x10xf32>

mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp

Lines changed: 1 addition & 3 deletions
@@ -77,10 +77,8 @@ struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
     if constexpr (llvm::is_one_of<T, omp::AtomicUpdateOp, omp::AtomicWriteOp,
                                   omp::FlushOp, omp::MapBoundsOp,
                                   omp::ThreadprivateOp>::value) {
-      if (isa<MemRefType>(originalOperand.getType())) {
-        // TODO: Support Memref PackOp. Temporarily return failure.
+      if (isa<MemRefType>(originalOperand.getType()))
         return rewriter.notifyMatchFailure(op, "memref is not supported yet");
-      }
     }
     convertedOperands.push_back(convertedOperand);
   }

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 33 additions & 31 deletions
@@ -9,8 +9,8 @@
 // This file implements the Linalg operations.
 //
 //===----------------------------------------------------------------------===//
+
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
-#include <iostream>

 #include "mlir/AsmParser/AsmParser.h"
 #include "mlir/Dialect/Affine/IR/AffineOps.h"
@@ -4486,13 +4486,23 @@ static LogicalResult commonVerifierPackAndUnPackOp(OpTy packOrUnPack) {
   // Verify result shape is greater than the minimum expected
   // by the pack operation, and that the output shape
   // represents full tiles.
-  auto expectedPackedShape = PackOp::inferPackedShape(
+  SmallVector<int64_t> expectedPackedShape = PackOp::inferPackedShape(
       unpackedType.getShape(), packOrUnPack.getStaticTiles(),
      packOrUnPack.getInnerDimsPos(), packOrUnPack.getOuterDimsPerm());
+
   if (!areAllInBound(expectedPackedShape, packedType.getShape())) {
+    auto elementType = unpackedType.getElementType();
+    Type expectedType, actualType;
+    if (packOrUnPack.hasPureTensorSemantics()) {
+      expectedType = RankedTensorType::get(expectedPackedShape, elementType);
+      actualType = RankedTensorType::get(packedType.getShape(), elementType);
+    } else {
+      expectedType = MemRefType::get(expectedPackedShape, elementType);
+      actualType = MemRefType::get(packedType.getShape(), elementType);
+    }
     return op->emitError("the shape of output is not large enough to hold the "
                          "packed data. Expected at least ")
-           << expectedPackedShape << ", got " << packedType.getShape();
+           << expectedType << ", got " << actualType;
   }
   if (!llvm::all_of(
           llvm::zip(packedType.getShape().take_back(mixedTiles.size()),
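With this change the verifier materializes complete RankedTensorType or MemRefType values for both the inferred minimum and the actual output, so the diagnostic prints full types that respect the op's tensor vs. memref semantics instead of two bare integer lists. The bound check itself is element-wise; here is a hedged standalone sketch of the semantics areAllInBound is assumed to have (not the verbatim MLIR helper): dynamic extents are deferred, static extents must cover the inferred minimum.

#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamic

// Assumed semantics of the verifier's bound check: a static output extent
// must hold the inferred minimum, while dynamic extents on either side
// are deferred to runtime.
bool areAllInBoundSketch(const std::vector<int64_t> &expected,
                         const std::vector<int64_t> &actual) {
  assert(expected.size() == actual.size() && "rank mismatch");
  for (size_t i = 0; i < expected.size(); ++i) {
    if (expected[i] == kDynamic || actual[i] == kDynamic)
      continue;
    if (actual[i] < expected[i])
      return false; // the output is too small to hold the packed data
  }
  return true;
}

int main() {
  assert(areAllInBoundSketch({4, 16, 32, 16}, {4, 16, 32, 16}));
  assert(!areAllInBoundSketch({4, 16, 32, 16}, {3, 16, 32, 16}));
  assert(areAllInBoundSketch({4, kDynamic, 32, 16}, {4, 8, 32, 16}));
  return 0;
}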
@@ -4696,13 +4706,11 @@ asShapeWithAnyValueAsDynamic(ArrayRef<OpFoldResult> ofrs) {
   return result;
 }

-/// Helper for PackOp::{getResultShape,inferPackedTensorType}. Returns the shape
-/// of the packed type. Having a shared helper helps implement these two methods
-/// in a way that ensures that they agree on which dimensions are dynamic.
-static SmallVector<int64_t> getPackOpResultTypeShape(
-    ArrayRef<int64_t> sourceShape, ArrayRef<int64_t> innerTileSizes,
-    ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = llvm::to_vector(sourceShape);
+SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
+                                              ArrayRef<int64_t> innerTileSizes,
+                                              ArrayRef<int64_t> innerDimsPos,
+                                              ArrayRef<int64_t> outerDimsPerm) {
+  SmallVector<int64_t> resultShape = llvm::to_vector(inputShape);
   for (auto tiledDim : llvm::enumerate(llvm::to_vector(innerDimsPos))) {
     if (ShapedType::isDynamic(resultShape[tiledDim.value()]))
       continue;
@@ -4742,9 +4750,9 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
   resultDims.append(innerTileSizes.begin(), innerTileSizes.end());

   SmallVector<int64_t> resultTypeShape =
-      getPackOpResultTypeShape(asShapeWithAnyValueAsDynamic(sourceDims),
-                               asShapeWithAnyValueAsDynamic(innerTileSizes),
-                               innerDimsPos, outerDimsPerm);
+      inferPackedShape(asShapeWithAnyValueAsDynamic(sourceDims),
+                       asShapeWithAnyValueAsDynamic(innerTileSizes),
+                       innerDimsPos, outerDimsPerm);

   // Fix-up `resultDims` to ensure that they are Value's if and only if the
   // result type shape says it's a dynamic dim. This is needed as callers may
@@ -4765,7 +4773,7 @@ SmallVector<OpFoldResult> PackOp::getResultShape(
 RankedTensorType PackOp::inferPackedTensorType(
     RankedTensorType sourceType, ArrayRef<int64_t> innerTileSizes,
     ArrayRef<int64_t> innerDimsPos, ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+  SmallVector<int64_t> resultShape = inferPackedShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
   return RankedTensorType::get(resultShape, sourceType.getElementType());
 }
@@ -4774,19 +4782,11 @@ MemRefType PackOp::inferPackedMemRefType(MemRefType sourceType,
                                          ArrayRef<int64_t> innerTileSizes,
                                          ArrayRef<int64_t> innerDimsPos,
                                          ArrayRef<int64_t> outerDimsPerm) {
-  SmallVector<int64_t> resultShape = getPackOpResultTypeShape(
+  SmallVector<int64_t> resultShape = inferPackedShape(
       sourceType.getShape(), innerTileSizes, innerDimsPos, outerDimsPerm);
   return MemRefType::get(resultShape, sourceType.getElementType());
 }

-SmallVector<int64_t> PackOp::inferPackedShape(ArrayRef<int64_t> inputShape,
-                                              ArrayRef<int64_t> innerTileSizes,
-                                              ArrayRef<int64_t> innerDimsPos,
-                                              ArrayRef<int64_t> outerDimsPerm) {
-  return getPackOpResultTypeShape(inputShape, innerTileSizes, innerDimsPos,
-                                  outerDimsPerm);
-}
-
 Value PackOp::createDestinationTensor(OpBuilder &b, Location loc, Value source,
                                       ArrayRef<OpFoldResult> innerTileSizes,
                                       ArrayRef<int64_t> innerDimsPos,
@@ -5004,10 +5004,7 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     return success();
   }

-  // Insert tensor.cast if static shape inference is available..
-  bool hasTensorSemantics = packOp.hasPureTensorSemantics();
-
-  // TODO: support memref.cast if static shape inference is available.
+  // Insert tensor.cast ops if static shape inference is available.
   SmallVector<int64_t> srcShape, destShape;
   if (inferStaticShape(packOp, srcShape, destShape)) {
     Location loc = packOp.getLoc();
@@ -5033,15 +5030,19 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
     // Insert a cast if needed
     if (needUpdateDestType) {
       rewriter.setInsertionPointAfter(packOp);
+      Operation *castOp;
+      bool hasTensorSemantics = packOp.hasPureTensorSemantics();
       if (hasTensorSemantics) {
-        auto castOp =
+        castOp =
             rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
-        rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
       } else {
-        auto castOp =
+        castOp =
             rewriter.create<memref::CastOp>(loc, originalResultType, packOp);
-        rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
       }
+      rewriter.replaceAllUsesExcept(packOp, castOp->getResult(0), castOp);
+    } else {
+      // TODO: support memref.cast if static shape inference is available.
+      return failure();
     }
     return success();
   }
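Two changes land here: castOp is hoisted above the branch as a plain Operation * so that a single rewriter.replaceAllUsesExcept call serves both the tensor::CastOp and the memref::CastOp path (the old code duplicated the call with a branch-local auto castOp), and the no-update case now returns failure() with a TODO for memref.cast. A trivial standalone analogue of the hoisting pattern, with hypothetical factory functions standing in for the rewriter.create calls:

#include <cassert>
#include <memory>
#include <string>

struct Op {
  std::string name;
};

// Hypothetical stand-ins for rewriter.create<tensor::CastOp>(...) and
// rewriter.create<memref::CastOp>(...).
std::unique_ptr<Op> createTensorCast() { return std::make_unique<Op>(Op{"tensor.cast"}); }
std::unique_ptr<Op> createMemrefCast() { return std::make_unique<Op>(Op{"memref.cast"}); }

int main() {
  bool hasTensorSemantics = false;
  // Declare the result before branching so the common tail (in the real
  // pattern, the single replaceAllUsesExcept call) is written only once.
  std::unique_ptr<Op> castOp;
  if (hasTensorSemantics)
    castOp = createTensorCast();
  else
    castOp = createMemrefCast();
  assert(castOp->name == "memref.cast");
  return 0;
}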
@@ -5423,6 +5424,7 @@ struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
     if (!tensor::hasFoldableTensorCastOperand(op))
       return failure();

+    // TODO: Support Memref PackOp. Temporarily return failure.
     if (!op.hasPureTensorSemantics())
       return failure();

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 0 additions & 6 deletions
@@ -63,9 +63,6 @@ getPackingInfoFromOperand(OpOperand *opOperand, linalg::GenericOp genericOp,
                           OpTy packOrUnPackOp) {
   static_assert(llvm::is_one_of<OpTy, linalg::PackOp, linalg::UnPackOp>::value,
                 "applies to only pack or unpack operations");
-  if (!packOrUnPackOp.hasPureTensorSemantics())
-    return failure();
-
   LLVM_DEBUG(
       { llvm::dbgs() << "--- Construct PackInfo From an operand ---\n"; });

@@ -376,9 +373,6 @@ static GenericOp packGenericOp(RewriterBase &rewriter, GenericOp genericOp,
 static FailureOr<GenericOp>
 bubbleUpPackOpThroughGenericOp(RewriterBase &rewriter, linalg::PackOp packOp,
                                const ControlPropagationFn &controlFn) {
-  if (!packOp.hasPureTensorSemantics())
-    return failure();
-
   auto genericOp = packOp.getSource().getDefiningOp<GenericOp>();
   if (!genericOp)
     return failure();

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 7 additions & 18 deletions
@@ -219,7 +219,6 @@ struct PackedOperandsDimList {
 FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
                                              linalg::PackOp packOp,
                                              bool lowerPadLikeWithInsertSlice) {
-  // TODO: Support Memref PackOp. Temporarily return failure.
   if (!packOp.hasPureTensorSemantics())
     return failure();

@@ -269,6 +268,7 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
     highs[pos] = affine::makeComposedFoldedAffineApply(
         rewriter, loc, map, {outerSize, origSize, innerSize});
   }
+  // TODO: Need memref.pad operation to support memref operands
   RankedTensorType collapsed = tensor::CollapseShapeOp::inferCollapsedType(
       RankedTensorType::Builder(packedTensorType).setShape(stripMinedShape),
       packingMetadata.reassociations);
@@ -359,19 +359,11 @@ FailureOr<LowerPackResult> linalg::lowerPack(RewriterBase &rewriter,
 FailureOr<LowerUnPackOpResult>
 linalg::lowerUnPack(RewriterBase &rewriter, linalg::UnPackOp unPackOp,
                     bool lowerUnpadLikeWithExtractSlice) {
-  // TODO: Support Memref PackOp. Temporarily return failure.
-  if (!unPackOp.hasPureTensorSemantics()) {
-    return failure();
-  }
-
   Location loc = unPackOp->getLoc();
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unPackOp);

-  auto packedTensorType = dyn_cast<RankedTensorType>(unPackOp.getSourceType());
-  if (!packedTensorType)
-    return failure();
-
+  auto packedTensorType = cast<RankedTensorType>(unPackOp.getSourceType());
   int64_t packedRank = packedTensorType.getRank();

   OpFoldResult zero = rewriter.getIndexAttr(0), one = rewriter.getIndexAttr(1);
@@ -1038,14 +1030,14 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
 static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
                                            linalg::PackOp packOp) {
   Value input = packOp.getSource();
+  // TODO: Support Memref PackOp. Temporarily return just Op Source.
+  if (!packOp.hasPureTensorSemantics())
+    return input;
+
   if (!packOp.getPaddingValue()) {
     return input;
   }

-  // TODO: Support Memref PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics())
-    return packOp.getSource();
-
   assert(llvm::all_of(packOp.getAllOuterDims(),
                       [](int64_t val) { return val == 1; }) &&
          "some outer dims are != 1");
@@ -1158,7 +1150,6 @@ getPackUnpackRankReducedPerm(ArrayRef<int64_t> shape,

 LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(
     linalg::PackOp packOp, PatternRewriter &rewriter) const {
-  // TODO: Support Memref PackOp. Temporarily return failure.
   if (!packOp.hasPureTensorSemantics())
     return failure();

@@ -1263,10 +1254,8 @@ LogicalResult DecomposeOuterUnitDimsPackOpPattern::matchAndRewrite(

 LogicalResult DecomposeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
     linalg::UnPackOp unpackOp, PatternRewriter &rewriter) const {
-  // TODO: Support Memref PackOp. Temporarily return failure.
-  if (!unpackOp.hasPureTensorSemantics()) {
+  if (!unpackOp.hasPureTensorSemantics())
     return failure();
-  }

   int64_t srcRank = unpackOp.getSourceRank();
   int64_t destRank = unpackOp.getDestRank();

mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp

Lines changed: 1 addition & 11 deletions
@@ -1588,10 +1588,6 @@ static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, linalg::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
-  // TODO: Support Memref PackOp. Temporarily return failure.
-  if (!packOp.hasPureTensorSemantics())
-    return failure();
-
   // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
@@ -1668,17 +1664,11 @@ static LogicalResult
 vectorizeAsTensorUnpackOp(RewriterBase &rewriter, linalg::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           SmallVectorImpl<Value> &newResults) {
-  // TODO: Support Memref PackOp. Temporarily return failure.
-  if (!unpackOp.hasPureTensorSemantics())
-    return failure();
-
   // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);

-  auto unpackTensorType = dyn_cast<RankedTensorType>(unpackOp.getSourceType());
-  if (!unpackTensorType)
-    return failure();
+  auto unpackTensorType = cast<RankedTensorType>(unpackOp.getSourceType());

   ArrayRef<int64_t> innerDimPos = unpackOp.getInnerDimsPos();
   ArrayRef<int64_t> innerTiles = unpackOp.getStaticInnerTiles();
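Removing the guards goes hand in hand with turning dyn_cast plus a null check into a plain cast: callers are now expected to guarantee a ranked tensor source, so a mismatch is a programmer error that should assert loudly rather than surface as a silent match failure. A plain-C++ analogue of the two casting styles, with dynamic_cast and a hypothetical RankedTensorTy standing in for LLVM's casting utilities:

#include <cassert>

struct Type {
  virtual ~Type() = default;
};
struct RankedTensorTy : Type {};

// dyn_cast style: returns nullptr on a type mismatch, so every caller
// must branch on the result.
RankedTensorTy *dynCastSketch(Type *t) {
  return dynamic_cast<RankedTensorTy *>(t);
}

// cast style: a mismatch is a programmer error, so assert instead of
// threading an impossible failure through the caller.
RankedTensorTy *castSketch(Type *t) {
  auto *ranked = dynamic_cast<RankedTensorTy *>(t);
  assert(ranked && "expected a RankedTensorTy");
  return ranked;
}

int main() {
  RankedTensorTy rankedTensor;
  Type *t = &rankedTensor;
  assert(dynCastSketch(t) != nullptr);
  assert(castSketch(t) != nullptr);
  return 0;
}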

mlir/lib/Dialect/Utils/ReshapeOpsUtils.cpp

Lines changed: 6 additions & 5 deletions
@@ -315,11 +315,12 @@ SmallVector<Range> SliceFromCollapseHelper::getExtractSliceParams(
     // have proven that these are not sliced. In this case we just take
     // the full extent of each dimension in the reassociation list.
     if (linearizedDimensions[it.index()]) {
-      llvm::append_range(offsetsSizesAndStrides,
-                         llvm::map_range(it.value(), [&](int64_t idx) -> Range {
-                           return {zeroAttr, collapseShapeInputShape[idx],
-                                   oneAttr};
-                         }));
+      llvm::append_range(
+          offsetsSizesAndStrides,
+          llvm::map_range(it.value(), [&](int64_t idx) -> Range {
+            return {zeroAttr, collapseShapeInputShape[idx], oneAttr};
+          }));
+
       continue;
     }
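The reflow is behavior-preserving: this branch handles reassociation groups that are provably not sliced by emitting a full-extent Range (offset 0, the dimension's size, stride 1) for each source dimension in the group. A standalone sketch of that per-group expansion, with Range and the attribute machinery simplified to plain integers:

#include <cstdint>
#include <vector>

// Simplified stand-in for mlir::Range; the real struct holds OpFoldResults.
struct RangeSketch {
  int64_t offset, size, stride;
};

// For a reassociation group that is provably not sliced, emit one
// full-extent range per source dimension in the group.
std::vector<RangeSketch>
fullExtentRanges(const std::vector<int64_t> &inputShape,
                 const std::vector<int64_t> &group) {
  std::vector<RangeSketch> ranges;
  for (int64_t idx : group)
    ranges.push_back({/*offset=*/0, /*size=*/inputShape[idx], /*stride=*/1});
  return ranges;
}

int main() {
  // For the tensor<3x7x11x10xf32> example quoted earlier, the group
  // [0, 1, 2] expands to [0,3,1], [0,7,1] and [0,11,1].
  auto ranges = fullExtentRanges({3, 7, 11, 10}, {0, 1, 2});
  return ranges.size() == 3 ? 0 : 1;
}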
