From 1554619817552db5a6e316e1a8176ffd2e63edcd Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Tue, 7 May 2024 16:27:53 +0200 Subject: [PATCH 1/2] TosaToTensor: Support reshape on unsigned --- .../mlir/Conversion/TosaToTensor/TosaToTensor.h | 4 +++- mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp | 12 ++++++++---- .../Conversion/TosaToTensor/TosaToTensorPass.cpp | 6 +++++- .../Conversion/TosaToTensor/tosa-to-tensor.mlir | 13 +++++++++++++ 4 files changed, 29 insertions(+), 6 deletions(-) diff --git a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h index 3953c83f3aa10..76a4b1b156336 100644 --- a/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h +++ b/mlir/include/mlir/Conversion/TosaToTensor/TosaToTensor.h @@ -16,6 +16,7 @@ #include "mlir/Pass/Pass.h" namespace mlir { +class TypeConverter; #define GEN_PASS_DECL_TOSATOTENSOR #include "mlir/Conversion/Passes.h.inc" @@ -24,7 +25,8 @@ namespace tosa { std::unique_ptr createTosaToTensor(); -void populateTosaToTensorConversionPatterns(RewritePatternSet *patterns); +void populateTosaToTensorConversionPatterns(TypeConverter &converter, + RewritePatternSet *patterns); } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp index 505d85f211111..8d3ac626f5d8c 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensor.cpp @@ -207,7 +207,11 @@ class ReshapeConverterCollapseExpand matchAndRewrite(tosa::ReshapeOp reshape, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const final { ShapedType operandTy = cast(adaptor.getInput1().getType()); - ShapedType resultTy = cast(reshape.getType()); + ShapedType resultTy = cast_if_present(getTypeConverter()->convertType(reshape.getType())); + if (!resultTy) { + return rewriter.notifyMatchFailure( + reshape.getLoc(), "could not convert result type"); + } bool isDynamic = !operandTy.hasStaticShape(); SmallVector intermediateShape; @@ -218,7 +222,7 @@ class ReshapeConverterCollapseExpand "the given two shapes"); } auto intermediateTy = RankedTensorType::get( - intermediateShape, reshape.getType().getElementType()); + intermediateShape, resultTy.getElementType()); Value collapse = createCollapse(rewriter, reshape.getLoc(), intermediateTy, adaptor.getInput1()); @@ -415,9 +419,9 @@ struct ConcatConverter : public OpConversionPattern { } // namespace void mlir::tosa::populateTosaToTensorConversionPatterns( - RewritePatternSet *patterns) { + TypeConverter &converter, RewritePatternSet *patterns) { patterns->add( patterns->getContext()); - patterns->add(patterns->getContext()); + patterns->add(converter, patterns->getContext()); } diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp index 50dc55667fb94..9ae5edcce291e 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp @@ -20,6 +20,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include namespace mlir { #define GEN_PASS_DEF_TOSATOTENSOR @@ -42,7 +43,10 @@ struct TosaToTensor : public impl::TosaToTensorBase { target.addLegalDialect(); target.addLegalDialect(); - mlir::tosa::populateTosaToTensorConversionPatterns(&patterns); + TypeConverter converter; + mlir::tosa::populateTosaToLinalgTypeConversion(converter); + + mlir::tosa::populateTosaToTensorConversionPatterns(converter, &patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir index ea121565d96dc..455c1303df037 100644 --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -56,6 +56,19 @@ func.func @test_reshape_samerank(%arg0: tensor<3x2xf32>) -> tensor<2x3xf32> { // ----- +// CHECK-LABEL: @test_reshape_samerank_unsigned +// CHECK-SAME: (%[[ARG0:.*]]: tensor<3x2xui8>) +func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3xui8> { + // CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8> + // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8> + // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] : tensor<6xi8> into tensor<2x3xi8> + // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8 + %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2xui8>) -> tensor<2x3xui8> + // CHECK-NEXT: return %[[CAST2]] + return %0 : tensor<2x3xui8> +} +// ----- + // CHECK-LABEL: @test_reshape_samerank_dyn // CHECK-SAME: (%[[ARG0:.*]]: tensor) func.func @test_reshape_samerank_dyn(%arg0: tensor) -> tensor<2x?xf32> { From d87d72016be698ed21c26af3b6e2be759863749f Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 8 May 2024 16:44:15 +0200 Subject: [PATCH 2/2] comments --- mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp | 3 ++- mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp index 9ae5edcce291e..31049777323be 100644 --- a/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp +++ b/mlir/lib/Conversion/TosaToTensor/TosaToTensorPass.cpp @@ -12,6 +12,7 @@ #include "mlir/Conversion/TosaToTensor/TosaToTensor.h" +#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" @@ -20,7 +21,7 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include + namespace mlir { #define GEN_PASS_DEF_TOSATOTENSOR diff --git a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir index 455c1303df037..ce310f13c4c19 100644 --- a/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir +++ b/mlir/test/Conversion/TosaToTensor/tosa-to-tensor.mlir @@ -62,7 +62,7 @@ func.func @test_reshape_samerank_unsigned(%arg0: tensor<3x2xui8>) -> tensor<2x3x // CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : tensor<3x2xui8> to tensor<3x2xi8> // CHECK-NEXT: %[[RESHAPE1:.*]] = tensor.collapse_shape %[[CAST1]] {{\[}}[0, 1]] : tensor<3x2xi8> into tensor<6xi8> // CHECK-NEXT: %[[RESHAPE2:.*]] = tensor.expand_shape %[[RESHAPE1]] {{\[}}[0, 1]] : tensor<6xi8> into tensor<2x3xi8> - // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8 + // CHECK-NEXT: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[RESHAPE2]] : tensor<2x3xi8> to tensor<2x3xui8> %0 = "tosa.reshape"(%arg0) {new_shape = array} : (tensor<3x2xui8>) -> tensor<2x3xui8> // CHECK-NEXT: return %[[CAST2]] return %0 : tensor<2x3xui8>