97 changes: 97 additions & 0 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -109,6 +109,102 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
results.add<SelfConcatToTile>(context);
}

// Rewrites a reshape that inserts a 1-sized dim, followed by a tile on that
// dim, into a tile on the original shape followed by a reshape. This reduces
// the rank of tile ops.
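//
// For example (a sketch of the first test case in canonicalize.mlir, with the
// tile multiples written inline for brevity):
//
//   reshape %x : tensor<2x3xf32> -> tensor<1x2x3xf32>
//   tile with multiples [4, 1, 1] : tensor<1x2x3xf32> -> tensor<4x2x3xf32>
//
// becomes
//
//   tile %x with multiples [4, 1] : tensor<2x3xf32> -> tensor<8x3xf32>
//   reshape : tensor<8x3xf32> -> tensor<4x2x3xf32>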
Comment on lines +112 to +114
Contributor

Is the core canonicalization to minimize the dimensions needed to express the tile operation? If so, it would be better to do so regardless of whether a reshape is directly connected, and instead insert reshapes before and after as appropriate.

If this is not the core goal, then I question having this as a canonicalization, as matching tile will still have a guarantee that the tile op is in canonicalized form.

Contributor Author

I do not think it's a good idea to always do the tile rank reduction, as the inserted reshapes may prevent other patterns from matching.

I originally wrote this to target a specific case:

reshape 1x2x3 to 1x2x1x3
tile 1x2x1x3 to 1x2x4x3
reshape to 1x2x12

This could be canonicalized to:
tile 1x2x3 to 1x2x12

In this case the rank of the tile would be reduced and the reshapes completely canceled.

I marked this PR as a draft, as I think it needs more investigation into when it makes sense to do this and when it does not.
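Written out as TOSA IR (a sketch mirroring the syntax of the tests added in canonicalize.mlir; SSA names are illustrative), the case above is:

%r = tosa.reshape %x {new_shape = array<i64: 1, 2, 1, 3>} : (tensor<1x2x3xf32>) -> tensor<1x2x1x3xf32>
%m = tosa.const_shape {value = dense<[1, 1, 4, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
%t = tosa.tile %r, %m : (tensor<1x2x1x3xf32>, !tosa.shape<4>) -> tensor<1x2x4x3xf32>
%o = tosa.reshape %t {new_shape = array<i64: 1, 2, 12>} : (tensor<1x2x4x3xf32>) -> tensor<1x2x12xf32>

which should fold to

%m = tosa.const_shape {value = dense<[1, 1, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
%o = tosa.tile %x, %m : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<1x2x12xf32>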

struct TileOnOneSizedDim : public OpRewritePattern<tosa::TileOp> {
using OpRewritePattern<tosa::TileOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::TileOp tileOp,
PatternRewriter &rewriter) const override {
SmallVector<int64_t> multiples;
if (failed(tileOp.getConstantMultiples(multiples))) {
return rewriter.notifyMatchFailure(tileOp, "Requires constant multiples");
}
auto tileResultType = dyn_cast<ShapedType>(tileOp.getResult().getType());
if (!tileResultType || !tileResultType.hasStaticShape()) {
return rewriter.notifyMatchFailure(tileOp, "Requires static shaped types");
}
const auto originalTileShape = tileResultType.getShape();
auto producerReshape = tileOp.getInput1().getDefiningOp<ReshapeOp>();
if (!producerReshape) {
return rewriter.notifyMatchFailure(tileOp, "Producer is not a reshape");
}

auto reshapeInputType =
dyn_cast<ShapedType>(producerReshape->getOperand(0).getType());
auto reshapeResultType =
dyn_cast<ShapedType>(producerReshape->getResult(0).getType());
if (!reshapeInputType || !reshapeResultType ||
!reshapeInputType.hasStaticShape() ||
!reshapeResultType.hasStaticShape()) {
return rewriter.notifyMatchFailure(tileOp, "Requires static shaped types");
}
const auto reshapeInShape = reshapeInputType.getShape();
const auto reshapeOutShape = reshapeResultType.getShape();
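// Find the first 1-sized dim inserted by the reshape: the output dim is 1 and
// the next output dim matches the corresponding input dim.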
std::optional<size_t> firstAddedOneDim;
for (auto [idx, outDim] : llvm::enumerate(reshapeOutShape)) {
if (idx >= reshapeInShape.size()) {
return rewriter.notifyMatchFailure(
tileOp, "Did not find reshape just adding ones");
}
if (outDim == reshapeInShape[idx]) {
continue;
}
if (outDim == 1 && idx + 1 < reshapeOutShape.size() &&
reshapeOutShape[idx + 1] == reshapeInShape[idx]) {
firstAddedOneDim = idx;
break;
}
}
if (!firstAddedOneDim) {
return rewriter.notifyMatchFailure(
tileOp, "Producer reshape is not only adding one dims");
}
if (multiples[*firstAddedOneDim] == 1) {
return rewriter.notifyMatchFailure(
tileOp, "Tile does not tile the inserted 1-sized dim");
}
RewriterBase::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(tileOp);
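// The reshape's result shape with the inserted 1-sized dim dropped.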
SmallVector<int64_t> reshapeWithoutOneShape;
reshapeWithoutOneShape.reserve(reshapeOutShape.size() - 1);
for (auto [idx, dim] : llvm::enumerate(reshapeOutShape)) {
if (idx != *firstAddedOneDim) {
reshapeWithoutOneShape.push_back(dim);
}
}
auto removeOneDimReshape = rewriter.createOrFold<tosa::ReshapeOp>(
tileOp.getLoc(), producerReshape.getResult(), reshapeWithoutOneShape);
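// Fold the multiple of the dropped 1-sized dim into the dim that follows it.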
SmallVector<int64_t> newTileMultiples;
newTileMultiples.reserve(multiples.size());
for (auto [idx, multiple] : llvm::enumerate(multiples)) {
if (idx == *firstAddedOneDim) {
continue;
}
if (idx == *firstAddedOneDim + 1) {
newTileMultiples.push_back(multiple * multiples[*firstAddedOneDim]);
} else {
newTileMultiples.push_back(multiple);
}
}
auto newTileConstMults =
getTosaConstShape(rewriter, tileOp.getLoc(), newTileMultiples);
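// Result shape of the lower-rank tile.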
SmallVector<int64_t> newTileResultShape;
for (auto [dim, mult] :
llvm::zip_equal(reshapeWithoutOneShape, newTileMultiples)) {
newTileResultShape.push_back(dim * mult);
}
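// Create the lower-rank tile, then reshape its result back to the original
// tile result shape.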
auto newTileResultType = tileResultType.clone(newTileResultShape);
auto newTileOp = rewriter.create<tosa::TileOp>(
tileOp.getLoc(), newTileResultType, removeOneDimReshape, newTileConstMults);
auto reshapeToOriginalResultShape = rewriter.create<tosa::ReshapeOp>(
tileOp.getLoc(), newTileOp.getResult(), originalTileShape);
rewriter.replaceOp(tileOp, reshapeToOriginalResultShape.getResult());
return success();
}
};

struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {
using OpRewritePattern<tosa::TileOp>::OpRewritePattern;

@@ -155,6 +251,7 @@ struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {

void TileOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
results.add<TileOnOneSizedDim>(context);
results.add<FuseChainedTile>(context);
}

71 changes: 71 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,3 +1,4 @@
// Modifications (c) Copyright 2023 - 2025 Advanced Micro Devices, Inc. or its affiliates
// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s

// CHECK-LABEL: @argmax_nofold
@@ -989,6 +990,76 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {

// -----


// CHECK-LABEL: func.func @tile_on_one_sized_dim_front
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>) -> tensor<4x2x3xf32>
// CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[4, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[TILE:.*]] = tosa.tile %[[ARG]], %[[CST]] : (tensor<2x3xf32>, !tosa.shape<2>) -> tensor<8x3xf32>
// CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[TILE]] {new_shape = array<i64: 4, 2, 3>} : (tensor<8x3xf32>) -> tensor<4x2x3xf32>
// CHECK: return %[[RESHAPE]] : tensor<4x2x3xf32>
func.func @tile_on_one_sized_dim_front(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
%r = tosa.reshape %arg0 {new_shape = array<i64: 1, 2, 3>} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
%m = tosa.const_shape {value = dense<[4, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%t = tosa.tile %r, %m : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<4x2x3xf32>
return %t : tensor<4x2x3xf32>
}

// -----

// CHECK-LABEL: func.func @tile_on_one_sized_dim_middle
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3x5xf32>) -> tensor<2x4x3x5xf32>
// CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[TILE:.*]] = tosa.tile %[[ARG]], %[[CST]] : (tensor<2x3x5xf32>, !tosa.shape<3>) -> tensor<2x12x5xf32>
// CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[TILE]] {new_shape = array<i64: 2, 4, 3, 5>} : (tensor<2x12x5xf32>) -> tensor<2x4x3x5xf32>
// CHECK: return %[[RESHAPE]] : tensor<2x4x3x5xf32>
func.func @tile_on_one_sized_dim_middle(%arg0: tensor<2x3x5xf32>) -> tensor<2x4x3x5xf32> {
%r = tosa.reshape %arg0 {new_shape = array<i64: 2, 1, 3, 5>} : (tensor<2x3x5xf32>) -> tensor<2x1x3x5xf32>
%m = tosa.const_shape {value = dense<[1, 4, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
%t = tosa.tile %r, %m : (tensor<2x1x3x5xf32>, !tosa.shape<4>) -> tensor<2x4x3x5xf32>
return %t : tensor<2x4x3x5xf32>
}

// -----

// Negative: trailing added 1.
// CHECK-LABEL: func.func @tile_on_one_sized_dim_trailing_one_no_match
// CHECK: tosa.reshape %{{.*}} {new_shape = array<i64: 2, 3, 1>} : (tensor<2x3xf32>) -> tensor<2x3x1xf32>
// CHECK: tosa.tile %{{.*}} : (tensor<2x3x1xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
func.func @tile_on_one_sized_dim_trailing_one_no_match(%arg0: tensor<2x3xf32>) -> tensor<2x3x4xf32> {
%r = tosa.reshape %arg0 {new_shape = array<i64: 2, 3, 1>} : (tensor<2x3xf32>) -> tensor<2x3x1xf32>
%m = tosa.const_shape {value = dense<[1, 1, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
%t = tosa.tile %r, %m : (tensor<2x3x1xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
return %t : tensor<2x3x4xf32>
}

// -----

// Negative: multiplier on inserted 1-dim is 1.
// CHECK-LABEL: func.func @tile_on_one_sized_dim_multiplier_one_no_match
// CHECK: tosa.reshape %{{.*}} {new_shape = array<i64: 1, 2, 3>}
// CHECK: tosa.tile %{{.*}} : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<1x4x3xf32>
func.func @tile_on_one_sized_dim_multiplier_one_no_match(%arg0: tensor<2x3xf32>) -> tensor<1x4x3xf32> {
%r = tosa.reshape %arg0 {new_shape = array<i64: 1, 2, 3>} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
%m = tosa.const_shape {value = dense<[1, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%t = tosa.tile %r, %m : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<1x4x3xf32>
return %t : tensor<1x4x3xf32>
}

// -----

// Negative: reshape not only adding ones.
// CHECK-LABEL: func.func @tile_on_one_sized_dim_not_only_adding_ones_no_match
// CHECK: tosa.reshape %{{.*}} {new_shape = array<i64: 1, 3, 4>}
// CHECK: tosa.tile %{{.*}} : (tensor<1x3x4xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
func.func @tile_on_one_sized_dim_not_only_adding_ones_no_match(%arg0: tensor<2x6xf32>) -> tensor<2x3x4xf32> {
%r = tosa.reshape %arg0 {new_shape = array<i64: 1, 3, 4>} : (tensor<2x6xf32>) -> tensor<1x3x4xf32>
%m = tosa.const_shape {value = dense<[2, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
%t = tosa.tile %r, %m : (tensor<1x3x4xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
return %t : tensor<2x3x4xf32>
}

// -----

// CHECK-LABEL: @transpose_no_op
func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
// CHECK: return %arg0