@@ -109,6 +109,102 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
109
109
results.add <SelfConcatToTile>(context);
110
110
}
111
111
112
+ /* Rewrites reshapes that are adding 1-sized dims and are followed by a tile on
113
+ * the one-sized dim to a tile on the original shape followed by a reshape. This
114
+ * is done to reduce the rank of tile ops. */
115
+ struct TileOnOneSizedDim : public OpRewritePattern <tosa::TileOp> {
116
+ using OpRewritePattern<tosa::TileOp>::OpRewritePattern;
117
+
118
+ LogicalResult matchAndRewrite (tosa::TileOp tileOp,
119
+ PatternRewriter &rewriter) const override {
120
+ SmallVector<int64_t > multiplies;
121
+ if (failed (tileOp.getConstantMultiples (multiplies))) {
122
+ return rewriter.notifyMatchFailure (tileOp, " Requires const multiplies" );
123
+ }
124
+ auto tileResultType = dyn_cast<ShapedType>(tileOp.getResult ().getType ());
125
+ if (!tileResultType || !tileResultType.hasStaticShape ()) {
126
+ return rewriter.notifyMatchFailure (tileOp, " Requires static shaped types" );
127
+ }
128
+ const auto originalTileShape = tileResultType.getShape ();
129
+ auto producerReshape = tileOp.getInput1 ().getDefiningOp <ReshapeOp>();
130
+ if (!producerReshape) {
131
+ return rewriter.notifyMatchFailure (tileOp, " Producer is not a reshape" );
132
+ }
133
+
134
+ auto reshapeInputType =
135
+ dyn_cast<ShapedType>(producerReshape->getOperand (0 ).getType ());
136
+ auto reshapeResultType =
137
+ dyn_cast<ShapedType>(producerReshape->getResult (0 ).getType ());
138
+ if (!reshapeInputType || !reshapeResultType ||
139
+ !reshapeInputType.hasStaticShape () ||
140
+ !reshapeResultType.hasStaticShape ()) {
141
+ return rewriter.notifyMatchFailure (tileOp, " Requires static shaped types" );
142
+ }
143
+ const auto reshapeInShape = reshapeInputType.getShape ();
144
+ const auto reshapeOutShape = reshapeResultType.getShape ();
145
+ std::optional<size_t > firstAddedOneDim;
146
+ for (auto [idx, outDim] : llvm::enumerate (reshapeOutShape)) {
147
+ if (idx >= reshapeInShape.size ()) {
148
+ return rewriter.notifyMatchFailure (
149
+ tileOp, " Did not find reshape just adding ones" );
150
+ }
151
+ if (outDim == reshapeInShape[idx]) {
152
+ continue ;
153
+ }
154
+ if (outDim == 1 && idx + 1 < reshapeOutShape.size () &&
155
+ reshapeOutShape[idx + 1 ] == reshapeInShape[idx]) {
156
+ firstAddedOneDim = idx;
157
+ break ;
158
+ }
159
+ }
160
+ if (!firstAddedOneDim) {
161
+ return rewriter.notifyMatchFailure (
162
+ tileOp, " Producer reshape is not only adding one dims" );
163
+ }
164
+ if (multiplies[*firstAddedOneDim] == 1 ) {
165
+ return rewriter.notifyMatchFailure (
166
+ tileOp, " Tile is not on a one sized dimension" );
167
+ }
168
+ RewriterBase::InsertionGuard guard (rewriter);
169
+ rewriter.setInsertionPoint (tileOp);
170
+ SmallVector<int64_t > reshapeWithoutOneShape;
171
+ reshapeWithoutOneShape.reserve (reshapeOutShape.size () - 1 );
172
+ for (auto [idx, dim] : llvm::enumerate (reshapeOutShape)) {
173
+ if (idx != *firstAddedOneDim) {
174
+ reshapeWithoutOneShape.push_back (dim);
175
+ }
176
+ }
177
+ auto removeOneDimReshape = rewriter.createOrFold <tosa::ReshapeOp>(
178
+ tileOp.getLoc (), producerReshape.getResult (), reshapeWithoutOneShape);
179
+ SmallVector<int64_t > newTileMultiples;
180
+ newTileMultiples.reserve (multiplies.size ());
181
+ for (auto [idx, multiplie] : llvm::enumerate (multiplies)) {
182
+ if (idx == *firstAddedOneDim) {
183
+ continue ;
184
+ }
185
+ if (idx == *firstAddedOneDim + 1 ) {
186
+ newTileMultiples.push_back (multiplie * multiplies[*firstAddedOneDim]);
187
+ } else {
188
+ newTileMultiples.push_back (multiplie);
189
+ }
190
+ }
191
+ auto newTileConstMults =
192
+ getTosaConstShape (rewriter, tileOp.getLoc (), newTileMultiples);
193
+ SmallVector<int64_t > newTileResultShape;
194
+ for (auto [dim, mult] :
195
+ llvm::zip_equal (reshapeWithoutOneShape, newTileMultiples)) {
196
+ newTileResultShape.push_back (dim * mult);
197
+ }
198
+ auto newTileResultType = tileResultType.clone (newTileResultShape);
199
+ auto newTileOp = rewriter.create <tosa::TileOp>(
200
+ tileOp.getLoc (), newTileResultType, removeOneDimReshape, newTileConstMults);
201
+ auto reshapeToOriginalResultShape = rewriter.create <tosa::ReshapeOp>(
202
+ tileOp.getLoc (), newTileOp.getResult (), originalTileShape);
203
+ rewriter.replaceOp (tileOp, reshapeToOriginalResultShape.getResult ());
204
+ return success ();
205
+ }
206
+ };
207
+
112
208
struct FuseChainedTile : public OpRewritePattern <tosa::TileOp> {
113
209
using OpRewritePattern<tosa::TileOp>::OpRewritePattern;
114
210
@@ -155,6 +251,7 @@ struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {
155
251
156
252
void TileOp::getCanonicalizationPatterns (RewritePatternSet &results,
157
253
MLIRContext *context) {
254
+ results.add <TileOnOneSizedDim>(context);
158
255
results.add <FuseChainedTile>(context);
159
256
}
160
257
0 commit comments