Skip to content

Commit b8ee462

Browse files
committed
Rewrites reshapes that are adding 1-sized dims and are followed by a tile on the one-sized dim to a tile on the original shape followed by a reshape.
This is done to reduce the rank of tile ops.
1 parent e851bbc commit b8ee462

File tree

1 file changed

+97
-0
lines changed

1 file changed

+97
-0
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,102 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
109109
results.add<SelfConcatToTile>(context);
110110
}
111111

112+
/* Rewrites reshapes that are adding 1-sized dims and are followed by a tile on
113+
* the one-sized dim to a tile on the original shape followed by a reshape. This
114+
* is done to reduce the rank of tile ops. */
115+
struct TileOnOneSizedDim : public OpRewritePattern<tosa::TileOp> {
116+
using OpRewritePattern<tosa::TileOp>::OpRewritePattern;
117+
118+
LogicalResult matchAndRewrite(tosa::TileOp tileOp,
119+
PatternRewriter &rewriter) const override {
120+
SmallVector<int64_t> multiplies;
121+
if (failed(tileOp.getConstantMultiples(multiplies))) {
122+
return rewriter.notifyMatchFailure(tileOp, "Requires const multiplies");
123+
}
124+
auto tileResultType = dyn_cast<ShapedType>(tileOp.getResult().getType());
125+
if (!tileResultType || !tileResultType.hasStaticShape()) {
126+
return rewriter.notifyMatchFailure(tileOp, "Requires static shaped types");
127+
}
128+
const auto originalTileShape = tileResultType.getShape();
129+
auto producerReshape = tileOp.getInput1().getDefiningOp<ReshapeOp>();
130+
if (!producerReshape) {
131+
return rewriter.notifyMatchFailure(tileOp, "Producer is not a reshape");
132+
}
133+
134+
auto reshapeInputType =
135+
dyn_cast<ShapedType>(producerReshape->getOperand(0).getType());
136+
auto reshapeResultType =
137+
dyn_cast<ShapedType>(producerReshape->getResult(0).getType());
138+
if (!reshapeInputType || !reshapeResultType ||
139+
!reshapeInputType.hasStaticShape() ||
140+
!reshapeResultType.hasStaticShape()) {
141+
return rewriter.notifyMatchFailure(tileOp, "Requires static shaped types");
142+
}
143+
const auto reshapeInShape = reshapeInputType.getShape();
144+
const auto reshapeOutShape = reshapeResultType.getShape();
145+
std::optional<size_t> firstAddedOneDim;
146+
for (auto [idx, outDim] : llvm::enumerate(reshapeOutShape)) {
147+
if (idx >= reshapeInShape.size()) {
148+
return rewriter.notifyMatchFailure(
149+
tileOp, "Did not find reshape just adding ones");
150+
}
151+
if (outDim == reshapeInShape[idx]) {
152+
continue;
153+
}
154+
if (outDim == 1 && idx + 1 < reshapeOutShape.size() &&
155+
reshapeOutShape[idx + 1] == reshapeInShape[idx]) {
156+
firstAddedOneDim = idx;
157+
break;
158+
}
159+
}
160+
if (!firstAddedOneDim) {
161+
return rewriter.notifyMatchFailure(
162+
tileOp, "Producer reshape is not only adding one dims");
163+
}
164+
if (multiplies[*firstAddedOneDim] == 1) {
165+
return rewriter.notifyMatchFailure(
166+
tileOp, "Tile is not on a one sized dimension");
167+
}
168+
RewriterBase::InsertionGuard guard(rewriter);
169+
rewriter.setInsertionPoint(tileOp);
170+
SmallVector<int64_t> reshapeWithoutOneShape;
171+
reshapeWithoutOneShape.reserve(reshapeOutShape.size() - 1);
172+
for (auto [idx, dim] : llvm::enumerate(reshapeOutShape)) {
173+
if (idx != *firstAddedOneDim) {
174+
reshapeWithoutOneShape.push_back(dim);
175+
}
176+
}
177+
auto removeOneDimReshape = rewriter.createOrFold<tosa::ReshapeOp>(
178+
tileOp.getLoc(), producerReshape.getResult(), reshapeWithoutOneShape);
179+
SmallVector<int64_t> newTileMultiples;
180+
newTileMultiples.reserve(multiplies.size());
181+
for (auto [idx, multiplie] : llvm::enumerate(multiplies)) {
182+
if (idx == *firstAddedOneDim) {
183+
continue;
184+
}
185+
if (idx == *firstAddedOneDim + 1) {
186+
newTileMultiples.push_back(multiplie * multiplies[*firstAddedOneDim]);
187+
} else {
188+
newTileMultiples.push_back(multiplie);
189+
}
190+
}
191+
auto newTileConstMults =
192+
getTosaConstShape(rewriter, tileOp.getLoc(), newTileMultiples);
193+
SmallVector<int64_t> newTileResultShape;
194+
for (auto [dim, mult] :
195+
llvm::zip_equal(reshapeWithoutOneShape, newTileMultiples)) {
196+
newTileResultShape.push_back(dim * mult);
197+
}
198+
auto newTileResultType = tileResultType.clone(newTileResultShape);
199+
auto newTileOp = rewriter.create<tosa::TileOp>(
200+
tileOp.getLoc(), newTileResultType, removeOneDimReshape, newTileConstMults);
201+
auto reshapeToOriginalResultShape = rewriter.create<tosa::ReshapeOp>(
202+
tileOp.getLoc(), newTileOp.getResult(), originalTileShape);
203+
rewriter.replaceOp(tileOp, reshapeToOriginalResultShape.getResult());
204+
return success();
205+
}
206+
};
207+
112208
struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {
113209
using OpRewritePattern<tosa::TileOp>::OpRewritePattern;
114210

@@ -155,6 +251,7 @@ struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {
155251

156252
void TileOp::getCanonicalizationPatterns(RewritePatternSet &results,
157253
MLIRContext *context) {
254+
results.add<TileOnOneSizedDim>(context);
158255
results.add<FuseChainedTile>(context);
159256
}
160257

0 commit comments

Comments
 (0)