Commit 7276ee6
Rewrites a reshape that inserts 1-sized dims and is followed by a tile on an inserted 1-sized dim into a tile on the original shape followed by a reshape.
This is done to reduce the rank of tile ops.
1 parent 0e919ee commit 7276ee6
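For illustration, the front-dim case from the new tests. Before the rewrite, the tile runs at rank 3 on the inserted 1-sized dim:

  %r = tosa.reshape %arg0 {new_shape = array<i64: 1, 2, 3>} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
  %m = tosa.const_shape {value = dense<[4, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %t = tosa.tile %r, %m : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<4x2x3xf32>

After the rewrite, the tile runs at rank 2 on the original shape and a trailing reshape restores the expected result type (value names here are illustrative):

  %m = tosa.const_shape {value = dense<[4, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
  %t = tosa.tile %arg0, %m : (tensor<2x3xf32>, !tosa.shape<2>) -> tensor<8x3xf32>
  %res = tosa.reshape %t {new_shape = array<i64: 4, 2, 3>} : (tensor<8x3xf32>) -> tensor<4x2x3xf32>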

2 files changed: +168 -0 lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 97 additions & 0 deletions
@@ -109,6 +109,102 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
  results.add<SelfConcatToTile>(context);
}

/* Rewrites a reshape that inserts 1-sized dims and is followed by a tile on
 * an inserted 1-sized dim into a tile on the original shape followed by a
 * reshape. This is done to reduce the rank of tile ops. */
struct TileOnOneSizedDim : public OpRewritePattern<tosa::TileOp> {
  using OpRewritePattern<tosa::TileOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tosa::TileOp tileOp,
                                PatternRewriter &rewriter) const override {
    SmallVector<int64_t> multiples;
    if (failed(tileOp.getConstantMultiples(multiples))) {
      return rewriter.notifyMatchFailure(tileOp, "Requires constant multiples");
    }
    auto tileResultType = dyn_cast<ShapedType>(tileOp.getResult().getType());
    if (!tileResultType || !tileResultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(tileOp, "Requires static shaped types");
    }
    const auto originalTileShape = tileResultType.getShape();
    auto producerReshape = tileOp.getInput1().getDefiningOp<ReshapeOp>();
    if (!producerReshape) {
      return rewriter.notifyMatchFailure(tileOp, "Producer is not a reshape");
    }

    auto reshapeInputType =
        dyn_cast<ShapedType>(producerReshape->getOperand(0).getType());
    auto reshapeResultType =
        dyn_cast<ShapedType>(producerReshape->getResult(0).getType());
    if (!reshapeInputType || !reshapeResultType ||
        !reshapeInputType.hasStaticShape() ||
        !reshapeResultType.hasStaticShape()) {
      return rewriter.notifyMatchFailure(tileOp, "Requires static shaped types");
    }
    const auto reshapeInShape = reshapeInputType.getShape();
    const auto reshapeOutShape = reshapeResultType.getShape();
    // Find the first output dim that is an inserted 1: it must be 1 and the
    // following output dim must continue the input shape at this index.
    std::optional<size_t> firstAddedOneDim;
    for (auto [idx, outDim] : llvm::enumerate(reshapeOutShape)) {
      if (idx >= reshapeInShape.size()) {
        return rewriter.notifyMatchFailure(
            tileOp, "Did not find reshape just adding ones");
      }
      if (outDim == reshapeInShape[idx]) {
        continue;
      }
      if (outDim == 1 && idx + 1 < reshapeOutShape.size() &&
          reshapeOutShape[idx + 1] == reshapeInShape[idx]) {
        firstAddedOneDim = idx;
        break;
      }
    }
    if (!firstAddedOneDim) {
      return rewriter.notifyMatchFailure(
          tileOp, "Producer reshape does not just insert a 1-sized dim");
    }
    if (multiples[*firstAddedOneDim] == 1) {
      return rewriter.notifyMatchFailure(
          tileOp, "Tile does not multiply the 1-sized dim");
    }
    RewriterBase::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(tileOp);
    // Reshape to the shape without the inserted 1-sized dim; createOrFold
    // lets the two reshapes fold together.
    SmallVector<int64_t> reshapeWithoutOneShape;
    reshapeWithoutOneShape.reserve(reshapeOutShape.size() - 1);
    for (auto [idx, dim] : llvm::enumerate(reshapeOutShape)) {
      if (idx != *firstAddedOneDim) {
        reshapeWithoutOneShape.push_back(dim);
      }
    }
    auto removeOneDimReshape = rewriter.createOrFold<tosa::ReshapeOp>(
        tileOp.getLoc(), producerReshape.getResult(), reshapeWithoutOneShape);
    // Fold the multiple of the dropped dim into the following dim: tiling
    // [1, d] by [M, N] repeats the d-sized block M*N times, exactly like
    // tiling [d] by [M*N].
    SmallVector<int64_t> newTileMultiples;
    newTileMultiples.reserve(multiples.size());
    for (auto [idx, multiple] : llvm::enumerate(multiples)) {
      if (idx == *firstAddedOneDim) {
        continue;
      }
      if (idx == *firstAddedOneDim + 1) {
        newTileMultiples.push_back(multiple * multiples[*firstAddedOneDim]);
      } else {
        newTileMultiples.push_back(multiple);
      }
    }
    auto newTileConstMults =
        getTosaConstShape(rewriter, tileOp.getLoc(), newTileMultiples);
    SmallVector<int64_t> newTileResultShape;
    for (auto [dim, mult] :
         llvm::zip_equal(reshapeWithoutOneShape, newTileMultiples)) {
      newTileResultShape.push_back(dim * mult);
    }
    auto newTileResultType = tileResultType.clone(newTileResultShape);
    auto newTileOp = rewriter.create<tosa::TileOp>(
        tileOp.getLoc(), newTileResultType, removeOneDimReshape,
        newTileConstMults);
    auto reshapeToOriginalResultShape = rewriter.create<tosa::ReshapeOp>(
        tileOp.getLoc(), newTileOp.getResult(), originalTileShape);
    rewriter.replaceOp(tileOp, reshapeToOriginalResultShape.getResult());
    return success();
  }
};

struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {
  using OpRewritePattern<tosa::TileOp>::OpRewritePattern;

@@ -155,6 +251,7 @@ struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {

void TileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<TileOnOneSizedDim>(context);
  results.add<FuseChainedTile>(context);
}

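The pattern fires under the standard canonicalizer, e.g. mlir-opt -canonicalize as in the test file below. A minimal in-process driver sketch, assuming IR already parsed into a ModuleOp named module (the variable name and error handling are illustrative, not part of this change):

  #include "mlir/Pass/PassManager.h"
  #include "mlir/Transforms/Passes.h"

  // Run canonicalization; TileOp patterns such as TileOnOneSizedDim are
  // registered via getCanonicalizationPatterns and applied here.
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::createCanonicalizerPass());
  if (mlir::failed(pm.run(module)))
    llvm::errs() << "canonicalization failed\n";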
mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 71 additions & 0 deletions
@@ -1,3 +1,4 @@
// Modifications (c) Copyright 2023 - 2025 Advanced Micro Devices, Inc. or its affiliates
// RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s

// CHECK-LABEL: @argmax_nofold
@@ -989,6 +990,76 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {

// -----

// CHECK-LABEL: func.func @tile_on_one_sized_dim_front
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>) -> tensor<4x2x3xf32>
// CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[4, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK: %[[TILE:.*]] = tosa.tile %[[ARG]], %[[CST]] : (tensor<2x3xf32>, !tosa.shape<2>) -> tensor<8x3xf32>
// CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[TILE]] {new_shape = array<i64: 4, 2, 3>} : (tensor<8x3xf32>) -> tensor<4x2x3xf32>
// CHECK: return %[[RESHAPE]] : tensor<4x2x3xf32>
func.func @tile_on_one_sized_dim_front(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
  %r = tosa.reshape %arg0 {new_shape = array<i64: 1, 2, 3>} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
  %m = tosa.const_shape {value = dense<[4, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %t = tosa.tile %r, %m : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<4x2x3xf32>
  return %t : tensor<4x2x3xf32>
}

// -----

// CHECK-LABEL: func.func @tile_on_one_sized_dim_middle
// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3x5xf32>) -> tensor<2x4x3x5xf32>
// CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK: %[[TILE:.*]] = tosa.tile %[[ARG]], %[[CST]] : (tensor<2x3x5xf32>, !tosa.shape<3>) -> tensor<2x12x5xf32>
// CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[TILE]] {new_shape = array<i64: 2, 4, 3, 5>} : (tensor<2x12x5xf32>) -> tensor<2x4x3x5xf32>
// CHECK: return %[[RESHAPE]] : tensor<2x4x3x5xf32>
func.func @tile_on_one_sized_dim_middle(%arg0: tensor<2x3x5xf32>) -> tensor<2x4x3x5xf32> {
  %r = tosa.reshape %arg0 {new_shape = array<i64: 2, 1, 3, 5>} : (tensor<2x3x5xf32>) -> tensor<2x1x3x5xf32>
  %m = tosa.const_shape {value = dense<[1, 4, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %t = tosa.tile %r, %m : (tensor<2x1x3x5xf32>, !tosa.shape<4>) -> tensor<2x4x3x5xf32>
  return %t : tensor<2x4x3x5xf32>
}

// -----

// Negative: trailing added 1.
// CHECK-LABEL: func.func @tile_on_one_sized_dim_trailing_one_no_match
// CHECK: tosa.reshape %{{.*}} {new_shape = array<i64: 2, 3, 1>} : (tensor<2x3xf32>) -> tensor<2x3x1xf32>
// CHECK: tosa.tile %{{.*}} : (tensor<2x3x1xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
func.func @tile_on_one_sized_dim_trailing_one_no_match(%arg0: tensor<2x3xf32>) -> tensor<2x3x4xf32> {
  %r = tosa.reshape %arg0 {new_shape = array<i64: 2, 3, 1>} : (tensor<2x3xf32>) -> tensor<2x3x1xf32>
  %m = tosa.const_shape {value = dense<[1, 1, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %t = tosa.tile %r, %m : (tensor<2x3x1xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
  return %t : tensor<2x3x4xf32>
}

// -----

// Negative: multiplier on inserted 1-dim is 1.
// CHECK-LABEL: func.func @tile_on_one_sized_dim_multiplier_one_no_match
// CHECK: tosa.reshape %{{.*}} {new_shape = array<i64: 1, 2, 3>}
// CHECK: tosa.tile %{{.*}} : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<1x4x3xf32>
func.func @tile_on_one_sized_dim_multiplier_one_no_match(%arg0: tensor<2x3xf32>) -> tensor<1x4x3xf32> {
  %r = tosa.reshape %arg0 {new_shape = array<i64: 1, 2, 3>} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
  %m = tosa.const_shape {value = dense<[1, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %t = tosa.tile %r, %m : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<1x4x3xf32>
  return %t : tensor<1x4x3xf32>
}

// -----

// Negative: reshape not only adding ones.
// CHECK-LABEL: func.func @tile_on_one_sized_dim_not_only_adding_ones_no_match
// CHECK: tosa.reshape %{{.*}} {new_shape = array<i64: 1, 3, 4>}
// CHECK: tosa.tile %{{.*}} : (tensor<1x3x4xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
func.func @tile_on_one_sized_dim_not_only_adding_ones_no_match(%arg0: tensor<2x6xf32>) -> tensor<2x3x4xf32> {
  %r = tosa.reshape %arg0 {new_shape = array<i64: 1, 3, 4>} : (tensor<2x6xf32>) -> tensor<1x3x4xf32>
  %m = tosa.const_shape {value = dense<[2, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %t = tosa.tile %r, %m : (tensor<1x3x4xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
  return %t : tensor<2x3x4xf32>
}

// -----

// CHECK-LABEL: @transpose_no_op
func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
  // CHECK: return %arg0
