@@ -109,6 +109,102 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
109109 results.add <SelfConcatToTile>(context);
110110}
111111
112+ /* Rewrites reshapes that are adding 1-sized dims and are followed by a tile on
113+ * the one-sized dim to a tile on the original shape followed by a reshape. This
114+ * is done to reduce the rank of tile ops. */
115+ struct TileOnOneSizedDim : public OpRewritePattern <tosa::TileOp> {
116+ using OpRewritePattern<tosa::TileOp>::OpRewritePattern;
117+
118+ LogicalResult matchAndRewrite (tosa::TileOp tileOp,
119+ PatternRewriter &rewriter) const override {
120+ SmallVector<int64_t > multiplies;
121+ if (failed (tileOp.getConstantMultiples (multiplies))) {
122+ return rewriter.notifyMatchFailure (tileOp, " Requires const multiplies" );
123+ }
124+ auto tileResultType = dyn_cast<ShapedType>(tileOp.getResult ().getType ());
125+ if (!tileResultType || !tileResultType.hasStaticShape ()) {
126+ return rewriter.notifyMatchFailure (tileOp, " Requires static shaped types" );
127+ }
128+ const auto originalTileShape = tileResultType.getShape ();
129+ auto producerReshape = tileOp.getInput1 ().getDefiningOp <ReshapeOp>();
130+ if (!producerReshape) {
131+ return rewriter.notifyMatchFailure (tileOp, " Producer is not a reshape" );
132+ }
133+
134+ auto reshapeInputType =
135+ dyn_cast<ShapedType>(producerReshape->getOperand (0 ).getType ());
136+ auto reshapeResultType =
137+ dyn_cast<ShapedType>(producerReshape->getResult (0 ).getType ());
138+ if (!reshapeInputType || !reshapeResultType ||
139+ !reshapeInputType.hasStaticShape () ||
140+ !reshapeResultType.hasStaticShape ()) {
141+ return rewriter.notifyMatchFailure (tileOp, " Requires static shaped types" );
142+ }
143+ const auto reshapeInShape = reshapeInputType.getShape ();
144+ const auto reshapeOutShape = reshapeResultType.getShape ();
145+ std::optional<size_t > firstAddedOneDim;
146+ for (auto [idx, outDim] : llvm::enumerate (reshapeOutShape)) {
147+ if (idx >= reshapeInShape.size ()) {
148+ return rewriter.notifyMatchFailure (
149+ tileOp, " Did not find reshape just adding ones" );
150+ }
151+ if (outDim == reshapeInShape[idx]) {
152+ continue ;
153+ }
154+ if (outDim == 1 && idx + 1 < reshapeOutShape.size () &&
155+ reshapeOutShape[idx + 1 ] == reshapeInShape[idx]) {
156+ firstAddedOneDim = idx;
157+ break ;
158+ }
159+ }
160+ if (!firstAddedOneDim) {
161+ return rewriter.notifyMatchFailure (
162+ tileOp, " Producer reshape is not only adding one dims" );
163+ }
164+ if (multiplies[*firstAddedOneDim] == 1 ) {
165+ return rewriter.notifyMatchFailure (
166+ tileOp, " Tile is not on a one sized dimension" );
167+ }
168+ RewriterBase::InsertionGuard guard (rewriter);
169+ rewriter.setInsertionPoint (tileOp);
170+ SmallVector<int64_t > reshapeWithoutOneShape;
171+ reshapeWithoutOneShape.reserve (reshapeOutShape.size () - 1 );
172+ for (auto [idx, dim] : llvm::enumerate (reshapeOutShape)) {
173+ if (idx != *firstAddedOneDim) {
174+ reshapeWithoutOneShape.push_back (dim);
175+ }
176+ }
177+ auto removeOneDimReshape = rewriter.createOrFold <tosa::ReshapeOp>(
178+ tileOp.getLoc (), producerReshape.getResult (), reshapeWithoutOneShape);
179+ SmallVector<int64_t > newTileMultiples;
180+ newTileMultiples.reserve (multiplies.size ());
181+ for (auto [idx, multiplie] : llvm::enumerate (multiplies)) {
182+ if (idx == *firstAddedOneDim) {
183+ continue ;
184+ }
185+ if (idx == *firstAddedOneDim + 1 ) {
186+ newTileMultiples.push_back (multiplie * multiplies[*firstAddedOneDim]);
187+ } else {
188+ newTileMultiples.push_back (multiplie);
189+ }
190+ }
191+ auto newTileConstMults =
192+ getTosaConstShape (rewriter, tileOp.getLoc (), newTileMultiples);
193+ SmallVector<int64_t > newTileResultShape;
194+ for (auto [dim, mult] :
195+ llvm::zip_equal (reshapeWithoutOneShape, newTileMultiples)) {
196+ newTileResultShape.push_back (dim * mult);
197+ }
198+ auto newTileResultType = tileResultType.clone (newTileResultShape);
199+ auto newTileOp = rewriter.create <tosa::TileOp>(
200+ tileOp.getLoc (), newTileResultType, removeOneDimReshape, newTileConstMults);
201+ auto reshapeToOriginalResultShape = rewriter.create <tosa::ReshapeOp>(
202+ tileOp.getLoc (), newTileOp.getResult (), originalTileShape);
203+ rewriter.replaceOp (tileOp, reshapeToOriginalResultShape.getResult ());
204+ return success ();
205+ }
206+ };
207+
112208struct FuseChainedTile : public OpRewritePattern <tosa::TileOp> {
113209 using OpRewritePattern<tosa::TileOp>::OpRewritePattern;
114210
@@ -155,6 +251,7 @@ struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {
155251
156252void TileOp::getCanonicalizationPatterns (RewritePatternSet &results,
157253 MLIRContext *context) {
254+ results.add <TileOnOneSizedDim>(context);
158255 results.add <FuseChainedTile>(context);
159256}
160257
0 commit comments