From 7276ee61662f100b31e1b93c7354c5b0d8230d37 Mon Sep 17 00:00:00 2001
From: "Rickert, Jonas"
Date: Wed, 3 Sep 2025 22:06:22 +0100
Subject: [PATCH] Rewrite a reshape that inserts a 1-sized dim, followed by a
 tile on that dim, into a tile on the original shape followed by a reshape

This reduces the rank of tosa.tile ops.
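As an illustration of the rewrite, using the shapes from the new
@tile_on_one_sized_dim_front test (SSA names are illustrative only):

  // Before: the tile operates on the rank-3 result of a reshape that
  // inserts a leading 1-sized dim.
  %r = tosa.reshape %x {new_shape = array<i64: 1, 2, 3>} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
  %m = tosa.const_shape {value = dense<[4, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %t = tosa.tile %r, %m : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<4x2x3xf32>

  // After: the tile operates on the rank-2 original shape; the multiple of
  // the removed 1-sized dim is folded into the dim it was inserted in front
  // of (4 * 1 = 4), and a final reshape restores the original result shape.
  %m2 = tosa.const_shape {value = dense<[4, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
  %t2 = tosa.tile %x, %m2 : (tensor<2x3xf32>, !tosa.shape<2>) -> tensor<8x3xf32>
  %r2 = tosa.reshape %t2 {new_shape = array<i64: 4, 2, 3>} : (tensor<8x3xf32>) -> tensor<4x2x3xf32>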
---
 .../Dialect/Tosa/IR/TosaCanonicalizations.cpp | 97 +++++++++++++++++++
 mlir/test/Dialect/Tosa/canonicalize.mlir      | 71 ++++++++++++++
 2 files changed, 168 insertions(+)

diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index ddc02993892b..a66d173a2686 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -109,6 +109,102 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
   results.add<ConcatOptimization>(context);
 }
 
+/// Rewrites a reshape that only inserts a 1-sized dim and is followed by a
+/// tile on that dim into a tile on the original shape followed by a reshape.
+/// This reduces the rank of the tile op.
+struct TileOnOneSizedDim : public OpRewritePattern<tosa::TileOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::TileOp tileOp,
+                                PatternRewriter &rewriter) const override {
+    SmallVector<int64_t> multiples;
+    if (failed(tileOp.getConstantMultiples(multiples))) {
+      return rewriter.notifyMatchFailure(tileOp,
+                                         "requires constant multiples");
+    }
+    auto tileResultType =
+        dyn_cast<RankedTensorType>(tileOp.getResult().getType());
+    if (!tileResultType || !tileResultType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(tileOp,
+                                         "requires static shaped types");
+    }
+    const auto originalTileShape = tileResultType.getShape();
+    auto producerReshape = tileOp.getInput1().getDefiningOp<tosa::ReshapeOp>();
+    if (!producerReshape) {
+      return rewriter.notifyMatchFailure(tileOp, "producer is not a reshape");
+    }
+
+    auto reshapeInputType =
+        dyn_cast<RankedTensorType>(producerReshape->getOperand(0).getType());
+    auto reshapeResultType =
+        dyn_cast<RankedTensorType>(producerReshape->getResult(0).getType());
+    if (!reshapeInputType || !reshapeResultType ||
+        !reshapeInputType.hasStaticShape() ||
+        !reshapeResultType.hasStaticShape()) {
+      return rewriter.notifyMatchFailure(tileOp,
+                                         "requires static shaped types");
+    }
+    const auto reshapeInShape = reshapeInputType.getShape();
+    const auto reshapeOutShape = reshapeResultType.getShape();
+    // Find the first result dim that is a newly inserted 1: every dim before
+    // it matches the input shape, and the dim right after it matches the
+    // input dim at the same index.
+    std::optional<size_t> firstAddedOneDim;
+    for (auto [idx, outDim] : llvm::enumerate(reshapeOutShape)) {
+      if (idx >= reshapeInShape.size()) {
+        return rewriter.notifyMatchFailure(
+            tileOp, "did not find a reshape that just inserts a 1-sized dim");
+      }
+      if (outDim == reshapeInShape[idx]) {
+        continue;
+      }
+      if (outDim == 1 && idx + 1 < reshapeOutShape.size() &&
+          reshapeOutShape[idx + 1] == reshapeInShape[idx]) {
+        firstAddedOneDim = idx;
+        break;
+      }
+    }
+    if (!firstAddedOneDim) {
+      return rewriter.notifyMatchFailure(
+          tileOp, "producer reshape does not insert a 1-sized dim");
+    }
+    if (multiples[*firstAddedOneDim] == 1) {
+      return rewriter.notifyMatchFailure(
+          tileOp, "tile does not multiply the inserted 1-sized dim");
+    }
+    RewriterBase::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(tileOp);
+    // Shape of the new tile operand: the reshape result shape with the
+    // inserted 1-sized dim removed. Building it on top of the producer
+    // reshape yields a reshape-of-reshape that the existing reshape
+    // canonicalizations collapse.
+    SmallVector<int64_t> reshapeWithoutOneShape;
+    reshapeWithoutOneShape.reserve(reshapeOutShape.size() - 1);
+    for (auto [idx, dim] : llvm::enumerate(reshapeOutShape)) {
+      if (idx != *firstAddedOneDim) {
+        reshapeWithoutOneShape.push_back(dim);
+      }
+    }
+    auto removeOneDimReshape = rewriter.createOrFold<tosa::ReshapeOp>(
+        tileOp.getLoc(), reshapeResultType.clone(reshapeWithoutOneShape),
+        producerReshape.getResult(),
+        rewriter.getDenseI64ArrayAttr(reshapeWithoutOneShape));
+    // Fold the multiple of the removed dim into the dim it was inserted in
+    // front of.
+    SmallVector<int64_t> newTileMultiples;
+    newTileMultiples.reserve(multiples.size() - 1);
+    for (auto [idx, multiple] : llvm::enumerate(multiples)) {
+      if (idx == *firstAddedOneDim) {
+        continue;
+      }
+      if (idx == *firstAddedOneDim + 1) {
+        newTileMultiples.push_back(multiple * multiples[*firstAddedOneDim]);
+      } else {
+        newTileMultiples.push_back(multiple);
+      }
+    }
+    auto newTileConstMults =
+        getTosaConstShape(rewriter, tileOp.getLoc(), newTileMultiples);
+    SmallVector<int64_t> newTileResultShape;
+    newTileResultShape.reserve(reshapeWithoutOneShape.size());
+    for (auto [dim, mult] :
+         llvm::zip_equal(reshapeWithoutOneShape, newTileMultiples)) {
+      newTileResultShape.push_back(dim * mult);
+    }
+    auto newTileResultType = tileResultType.clone(newTileResultShape);
+    auto newTileOp = rewriter.create<tosa::TileOp>(
+        tileOp.getLoc(), newTileResultType, removeOneDimReshape,
+        newTileConstMults);
+    auto reshapeToOriginalResultShape = rewriter.create<tosa::ReshapeOp>(
+        tileOp.getLoc(), tileResultType, newTileOp.getResult(),
+        rewriter.getDenseI64ArrayAttr(originalTileShape));
+    rewriter.replaceOp(tileOp, reshapeToOriginalResultShape.getResult());
+    return success();
+  }
+};
+
 struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -155,6 +251,7 @@ struct FuseChainedTile : public OpRewritePattern<tosa::TileOp> {
 
 void TileOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
+  results.add<TileOnOneSizedDim>(context);
   results.add<FuseChainedTile>(context);
 }
 
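A note on the intermediate IR, sketched with the shapes of the
@tile_on_one_sized_dim_middle test below (SSA names are illustrative only):
the pattern builds the rank-reduced tile operand as a second reshape on top
of the producer reshape, so immediately after the rewrite the IR looks like

  %r  = tosa.reshape %x {new_shape = array<i64: 2, 1, 3, 5>} : (tensor<2x3x5xf32>) -> tensor<2x1x3x5xf32>
  %r2 = tosa.reshape %r {new_shape = array<i64: 2, 3, 5>} : (tensor<2x1x3x5xf32>) -> tensor<2x3x5xf32>
  %m  = tosa.const_shape {value = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
  %t  = tosa.tile %r2, %m : (tensor<2x3x5xf32>, !tosa.shape<3>) -> tensor<2x12x5xf32>
  %r3 = tosa.reshape %t {new_shape = array<i64: 2, 4, 3, 5>} : (tensor<2x12x5xf32>) -> tensor<2x4x3x5xf32>

The %r/%r2 pair composes to an identity reshape, which the existing
tosa.reshape folds/canonicalizations collapse away, leaving the tile to
consume %x directly, as the CHECK lines in the tests below expect.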
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 66c5904004c1..106e2cf16804 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1,3 +1,4 @@
+// Modifications (c) Copyright 2023 - 2025 Advanced Micro Devices, Inc. or its affiliates
 // RUN: mlir-opt --split-input-file -canonicalize="test-convergence" %s | FileCheck %s
 
 // CHECK-LABEL: @argmax_nofold
@@ -989,6 +990,76 @@ func.func @tile_nofold(%arg0: tensor<3x4xf32>) -> tensor<3x8xf32> {
 
 // -----
 
+
+// CHECK-LABEL: func.func @tile_on_one_sized_dim_front
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3xf32>) -> tensor<4x2x3xf32>
+// CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[4, 1]> : tensor<2xindex>} : () -> !tosa.shape<2>
+// CHECK: %[[TILE:.*]] = tosa.tile %[[ARG]], %[[CST]] : (tensor<2x3xf32>, !tosa.shape<2>) -> tensor<8x3xf32>
+// CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[TILE]] {new_shape = array<i64: 4, 2, 3>} : (tensor<8x3xf32>) -> tensor<4x2x3xf32>
+// CHECK: return %[[RESHAPE]] : tensor<4x2x3xf32>
+func.func @tile_on_one_sized_dim_front(%arg0: tensor<2x3xf32>) -> tensor<4x2x3xf32> {
+  %r = tosa.reshape %arg0 {new_shape = array<i64: 1, 2, 3>} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
+  %m = tosa.const_shape {value = dense<[4, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %t = tosa.tile %r, %m : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<4x2x3xf32>
+  return %t : tensor<4x2x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @tile_on_one_sized_dim_middle
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x3x5xf32>) -> tensor<2x4x3x5xf32>
+// CHECK: %[[CST:.*]] = tosa.const_shape {value = dense<[1, 4, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[TILE:.*]] = tosa.tile %[[ARG]], %[[CST]] : (tensor<2x3x5xf32>, !tosa.shape<3>) -> tensor<2x12x5xf32>
+// CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[TILE]] {new_shape = array<i64: 2, 4, 3, 5>} : (tensor<2x12x5xf32>) -> tensor<2x4x3x5xf32>
+// CHECK: return %[[RESHAPE]] : tensor<2x4x3x5xf32>
+func.func @tile_on_one_sized_dim_middle(%arg0: tensor<2x3x5xf32>) -> tensor<2x4x3x5xf32> {
+  %r = tosa.reshape %arg0 {new_shape = array<i64: 2, 1, 3, 5>} : (tensor<2x3x5xf32>) -> tensor<2x1x3x5xf32>
+  %m = tosa.const_shape {value = dense<[1, 4, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %t = tosa.tile %r, %m : (tensor<2x1x3x5xf32>, !tosa.shape<4>) -> tensor<2x4x3x5xf32>
+  return %t : tensor<2x4x3x5xf32>
+}
+
+// -----
+
+// Negative: the added 1-sized dim is trailing, so there is no following dim
+// to fold the tile multiple into.
+// CHECK-LABEL: func.func @tile_on_one_sized_dim_trailing_one_no_match
+// CHECK: tosa.reshape %{{.*}} {new_shape = array<i64: 2, 3, 1>} : (tensor<2x3xf32>) -> tensor<2x3x1xf32>
+// CHECK: tosa.tile %{{.*}} : (tensor<2x3x1xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
+func.func @tile_on_one_sized_dim_trailing_one_no_match(%arg0: tensor<2x3xf32>) -> tensor<2x3x4xf32> {
+  %r = tosa.reshape %arg0 {new_shape = array<i64: 2, 3, 1>} : (tensor<2x3xf32>) -> tensor<2x3x1xf32>
+  %m = tosa.const_shape {value = dense<[1, 1, 4]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %t = tosa.tile %r, %m : (tensor<2x3x1xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
+  return %t : tensor<2x3x4xf32>
+}
+
+// -----
+
+// Negative: the multiple on the inserted 1-sized dim is 1, so there is
+// nothing to fold.
+// CHECK-LABEL: func.func @tile_on_one_sized_dim_multiplier_one_no_match
+// CHECK: tosa.reshape %{{.*}} {new_shape = array<i64: 1, 2, 3>}
+// CHECK: tosa.tile %{{.*}} : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<1x4x3xf32>
+func.func @tile_on_one_sized_dim_multiplier_one_no_match(%arg0: tensor<2x3xf32>) -> tensor<1x4x3xf32> {
+  %r = tosa.reshape %arg0 {new_shape = array<i64: 1, 2, 3>} : (tensor<2x3xf32>) -> tensor<1x2x3xf32>
+  %m = tosa.const_shape {value = dense<[1, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %t = tosa.tile %r, %m : (tensor<1x2x3xf32>, !tosa.shape<3>) -> tensor<1x4x3xf32>
+  return %t : tensor<1x4x3xf32>
+}
+
+// -----
+
+// Negative: the reshape changes the shape beyond inserting a 1-sized dim.
+// CHECK-LABEL: func.func @tile_on_one_sized_dim_not_only_adding_ones_no_match
+// CHECK: tosa.reshape %{{.*}} {new_shape = array<i64: 1, 3, 4>}
+// CHECK: tosa.tile %{{.*}} : (tensor<1x3x4xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
+func.func @tile_on_one_sized_dim_not_only_adding_ones_no_match(%arg0: tensor<2x6xf32>) -> tensor<2x3x4xf32> {
+  %r = tosa.reshape %arg0 {new_shape = array<i64: 1, 3, 4>} : (tensor<2x6xf32>) -> tensor<1x3x4xf32>
+  %m = tosa.const_shape {value = dense<[2, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %t = tosa.tile %r, %m : (tensor<1x3x4xf32>, !tosa.shape<3>) -> tensor<2x3x4xf32>
+  return %t : tensor<2x3x4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @transpose_no_op
 func.func @transpose_no_op(%arg0: tensor<3x4x5x6xf32>) -> tensor<3x4x5x6xf32> {
   // CHECK: return %arg0
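One further observation, not covered by the tests above: since the pattern
peels off one inserted 1-sized dim per application, a reshape that inserts
several 1-sized dims should collapse to a single low-rank tile over repeated
applications. This is a sketch under the assumption that the greedy
canonicalization driver iterates this pattern to a fixed point together with
the existing reshape-of-reshape collapsing (shapes chosen for illustration):

  // Before canonicalization: two inserted 1-sized dims, both tiled.
  %r = tosa.reshape %x {new_shape = array<i64: 1, 2, 1, 3>} : (tensor<2x3xf32>) -> tensor<1x2x1x3xf32>
  %m = tosa.const_shape {value = dense<[4, 1, 5, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %t = tosa.tile %r, %m : (tensor<1x2x1x3xf32>, !tosa.shape<4>) -> tensor<4x2x5x3xf32>

  // Expected fixed point: a rank-2 tile with both multiples folded in.
  %m2 = tosa.const_shape {value = dense<[4, 5]> : tensor<2xindex>} : () -> !tosa.shape<2>
  %t2 = tosa.tile %x, %m2 : (tensor<2x3xf32>, !tosa.shape<2>) -> tensor<8x15xf32>
  %r2 = tosa.reshape %t2 {new_shape = array<i64: 4, 2, 5, 3>} : (tensor<8x15xf32>) -> tensor<4x2x5x3xf32>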