From 1891fff21472f41cd058855c01099a4dd6bf36fa Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 8 May 2024 11:03:50 +0100 Subject: [PATCH 01/17] Allow fp8 --- src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index b4fd9ab4d2..d47afc5014 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -82,6 +82,11 @@ void populateONNXToTOSAConversionPattern(ConversionTarget &target, target, patterns, typeConverter, ctx); } +inline bool isFloat8(mlir::Type type) { + return type.isa() && + type.cast().getWidth() == 8; +} + // Performs lowering to TOSA dialect struct FrontendToTosaLoweringPass : public PassWrapper> { @@ -120,8 +125,8 @@ void FrontendToTosaLoweringPass::runOnOperation() { // conversion failures. Quantized types are not supported right now. TypeConverter typeConverter; typeConverter.addConversion([](Type type) -> std::optional { - if (isTOSAInt(type) || isTOSAFloat(type) || type.isa() || - isTOSABool(type)) + if (isTOSAInt(type) || isTOSAFloat(type) || isFloat8(type) || + type.isa() || isTOSABool(type)) return type; return std::nullopt; }); From cafc2e0a3c93441effea7155e89cdd5e4623d2d1 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 8 May 2024 13:37:41 +0100 Subject: [PATCH 02/17] Add test --- .../conversion/onnx_to_tosa/Flow/fp8.mlir | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir diff --git a/test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir b/test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir new file mode 100644 index 0000000000..4107df7a4a --- /dev/null +++ b/test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir @@ -0,0 +1,36 @@ +// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s + +func.func @test_f8E4M3FNUZ(%arg0: tensor<13x21x3xf8E4M3FNUZ>) -> tensor<13x21x3xf8E4M3FNUZ> { + func.return %arg0 : tensor<13x21x3xf8E4M3FNUZ> +} +// CHECK: @test_f8E4M3FNUZ(%[[arg0:.*]]: tensor<13x21x3xf8E4M3FNUZ>) -> tensor<13x21x3xf8E4M3FNUZ> +// CHECK-NEXT: return %[[arg0]] : tensor<13x21x3xf8E4M3FNUZ> + + +func.func @test_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> { + func.return %arg0 : tensor<13x21x3xf8E4M3FN> +} +// CHECK: @test_f8E4M3FN(%[[arg0:.*]]: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> +// CHECK-NEXT: return %[[arg0]] : tensor<13x21x3xf8E4M3FN> + +func.func @test_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> { + func.return %arg0 : tensor<13x21x3xf8E5M2> +} + +// CHECK: func.func @test_f8E5M2(%[[arg0:.*]]: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> +// CHECK: return %[[arg0]] : tensor<13x21x3xf8E5M2> + +func.func @test_f8E5M2FNUZ(%arg0: tensor<13x21x3xf8E5M2FNUZ>) -> tensor<13x21x3xf8E5M2FNUZ> { + func.return %arg0 : tensor<13x21x3xf8E5M2FNUZ> +} + +// CHECK: @test_f8E5M2FNUZ(%[[arg0:.*]]: tensor<13x21x3xf8E5M2FNUZ>) -> tensor<13x21x3xf8E5M2FNUZ> { +// CHECK-NEXT: return %[[arg0]] : tensor<13x21x3xf8E5M2FNUZ> + +func.func @test_f8E4M3B11FNUZ(%arg0: tensor<13x21x3xf8E4M3B11FNUZ>) -> tensor<13x21x3xf8E4M3B11FNUZ> { + func.return %arg0 : tensor<13x21x3xf8E4M3B11FNUZ> +} + +// CHECK: @test_f8E4M3B11FNUZ(%[[arg0:.*]]: tensor<13x21x3xf8E4M3B11FNUZ>) -> tensor<13x21x3xf8E4M3B11FNUZ> { +// CHECK-NEXT: return %[[arg0]] : tensor<13x21x3xf8E4M3B11FNUZ> + From ff843445829a2bf3d528f3d599d548cf3bff2386 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 8 May 2024 14:03:52 +0100 Subject: [PATCH 03/17] Address comments --- src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp | 9 ++------- test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir | 2 +- 2 files changed, 3 insertions(+), 8 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index d47afc5014..c9c21724ec 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -82,11 +82,6 @@ void populateONNXToTOSAConversionPattern(ConversionTarget &target, target, patterns, typeConverter, ctx); } -inline bool isFloat8(mlir::Type type) { - return type.isa() && - type.cast().getWidth() == 8; -} - // Performs lowering to TOSA dialect struct FrontendToTosaLoweringPass : public PassWrapper> { @@ -125,8 +120,8 @@ void FrontendToTosaLoweringPass::runOnOperation() { // conversion failures. Quantized types are not supported right now. TypeConverter typeConverter; typeConverter.addConversion([](Type type) -> std::optional { - if (isTOSAInt(type) || isTOSAFloat(type) || isFloat8(type) || - type.isa() || isTOSABool(type)) + if (isTOSAInt(type) || type.isa() || type.isa() || + isTOSABool(type)) return type; return std::nullopt; }); diff --git a/test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir b/test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir index 4107df7a4a..e4d97b845c 100644 --- a/test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir @@ -1,4 +1,4 @@ -// RUN: onnx-mlir-opt --shape-inference --convert-onnx-to-tosa %s -split-input-file | FileCheck %s +// RUN: onnx-mlir-opt --convert-onnx-to-tosa %s -split-input-file | FileCheck %s func.func @test_f8E4M3FNUZ(%arg0: tensor<13x21x3xf8E4M3FNUZ>) -> tensor<13x21x3xf8E4M3FNUZ> { func.return %arg0 : tensor<13x21x3xf8E4M3FNUZ> From 57b081cf3438791bd403c09d59b9c577e33bde6b Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 8 May 2024 14:09:53 +0100 Subject: [PATCH 04/17] Remove all ocurrences of isTOSAFloat --- src/Conversion/ONNXToTOSA/Math/Elementwise.cpp | 6 +++--- src/Conversion/ONNXToTOSA/Tensor/Resize.cpp | 3 ++- src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index e882e85713..115e067e96 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -60,7 +60,7 @@ struct ErfIOSupportedTypes { struct IsAnyLegalType { static LogicalResult checkType( ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { - if (!isTOSAFloat(scalarType) && !isTOSAInt(scalarType) && + if (!scalarType.isa() && !isTOSAInt(scalarType) && !isTOSABool(scalarType)) { return rewriter.notifyMatchFailure( op, "this operation only supports signed integer or float types"); @@ -72,7 +72,7 @@ struct IsAnyLegalType { struct IsIntOrFloat { static LogicalResult checkType( ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { - if (!isTOSAFloat(scalarType) && !isTOSAInt(scalarType)) { + if (!scalarType.isa() && !isTOSAInt(scalarType)) { return rewriter.notifyMatchFailure( op, "this operation only supports signed integer or float types"); } @@ -94,7 +94,7 @@ struct IsInt { struct IsFloat { static LogicalResult checkType( ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { - if (!isTOSAFloat(scalarType)) { + if (!scalarType.isa()) { return rewriter.notifyMatchFailure( op, "this operation only supports float types"); } diff --git a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp index 1df4b90afa..54341c1be3 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp @@ -19,6 +19,7 @@ #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" #include +#include #include using namespace mlir; @@ -203,7 +204,7 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern { } auto elementType = inputType.getElementType(); - if (!(isTOSAFloat(elementType) || isTOSAInt(elementType))) { + if (!(elementType.isa() || isTOSAInt(elementType))) { return rewriter.notifyMatchFailure( resizeOp, "Element type is not supported by TOSA."); } diff --git a/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp b/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp index 9160b82c40..0b333b0e2c 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp @@ -53,7 +53,7 @@ class ONNXTransposeLoweringToTOSA Type inputElementType = inputType.getElementType(); - if (!isTOSAFloat(inputElementType) && !isTOSAInt(inputElementType) && + if (!inputElementType.isa() && !isTOSAInt(inputElementType) && !inputElementType.isInteger(1)) { return rewriter.notifyMatchFailure( op, "input element type not supported"); From a926028f2acd2d7ae42026825540abe4c94d1f42 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 8 May 2024 15:13:38 +0100 Subject: [PATCH 05/17] Fix Resize.mlir --- src/Dialect/ONNX/ONNXOps.td.inc | 4 ++-- .../conversion/onnx_to_tosa/NN/BatchNorm.mlir | 22 ++++++++++++++++--- .../onnx_to_tosa/Tensor/Resize.mlir | 18 +++++++++++++++ 3 files changed, 39 insertions(+), 5 deletions(-) diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index 35512ffe73..a3ebf40fba 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -7241,7 +7241,7 @@ def ONNXResizeOp:ONNX_Op<"Resize", ``` if input \\"sizes\\" is not specified. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[AnyFloat]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X, AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$roi, AnyTypeOf<[TensorOf<[F32]>, NoneType]>:$scales, AnyTypeOf<[TensorOf<[I64]>, NoneType]>:$sizes, @@ -7254,7 +7254,7 @@ def ONNXResizeOp:ONNX_Op<"Resize", DefaultValuedStrAttr:$keep_aspect_ratio_policy, DefaultValuedStrAttr:$mode, DefaultValuedStrAttr:$nearest_mode); - let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[AnyFloat]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 4; diff --git a/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir b/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir index 64a076d268..f28e8d2a20 100644 --- a/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir +++ b/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir @@ -85,7 +85,7 @@ func.func @test_batchnorm_bf16_dynamic(%arg0: tensor<100x3x?x?xbf16>) -> tensor< } // ----- -// tosa doesn't support f64, so it should not be lowered +// tosa doesn't support f64, but it is still lowered func.func @test_batchnorm_f64(%arg0: tensor<100x3x10x10xf64>) -> tensor<100x3x10x10xf64> { %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf64>} : () -> tensor<3xf64> @@ -94,6 +94,22 @@ func.func @test_batchnorm_f64(%arg0: tensor<100x3x10x10xf64>) -> tensor<100x3x10 %3 = "onnx.Constant"() {value = dense<[4.0, 5.0, 6.0]> : tensor<3xf64>} : () -> tensor<3xf64> %4 = "onnx.BatchNormalizationInferenceMode"(%arg0, %0, %1, %2, %3) {epsilon = 1.00000007E-5 : f32} : (tensor<100x3x10x10xf64>, tensor<3xf64>, tensor<3xf64>, tensor<3xf64>, tensor<3xf64>) -> tensor<100x3x10x10xf64> return %4 : tensor<100x3x10x10xf64> -// CHECK: onnx.BatchNormalizationInferenceMode -// CHECK-NOT: tosa +// CHECK-LABEL: @test_batchnorm_f64 +// CHECK-SAME: ([[PARAM_0:%.*]]: tensor<100x3x10x10xf64>) -> tensor<100x3x10x10xf64> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[1.000000e+00, 2.000000e+00, 3.000000e+00]> : tensor<3xf64>}> : () -> tensor<3xf64> +// CHECK-NEXT: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<[2.000000e+00, 3.000000e+00, 4.000000e+00]> : tensor<3xf64>}> : () -> tensor<3xf64> +// CHECK-NEXT: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<[3.000000e+00, 4.000000e+00, 5.000000e+00]> : tensor<3xf64>}> : () -> tensor<3xf64> +// CHECK-NEXT: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[4.000000e+00, 5.000000e+00, 6.000000e+00]> : tensor<3xf64>}> : () -> tensor<3xf64> +// CHECK-NEXT: [[VAR_4_:%.+]] = tosa.reshape [[VAR_2_]] {new_shape = array} : (tensor<3xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_5_:%.+]] = tosa.reshape [[VAR_0_]] {new_shape = array} : (tensor<3xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_6_:%.+]] = tosa.reshape [[VAR_1_]] {new_shape = array} : (tensor<3xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_7_:%.+]] = tosa.reshape [[VAR_3_]] {new_shape = array} : (tensor<3xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_8_:%.+]] = "tosa.const"() <{value = dense<1.0000000656873453E-5> : tensor<1x1x1x1xf64>}> : () -> tensor<1x1x1x1xf64> +// CHECK-NEXT: [[VAR_9_:%.+]] = tosa.sub %arg0, [[VAR_4_]] : (tensor<100x3x10x10xf64>, tensor<1x3x1x1xf64>) -> tensor<100x3x10x10xf64> +// CHECK-NEXT: [[VAR_10_:%.+]] = tosa.add %7, [[VAR_8_]] : (tensor<1x3x1x1xf64>, tensor<1x1x1x1xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_11_:%.+]] = tosa.rsqrt [[VAR_10_]] : (tensor<1x3x1x1xf64>) -> tensor<1x3x1x1xf64> +// CHECK-NEXT: [[VAR_12_:%.+]] = tosa.mul [[VAR_9_]], %11 {shift = 0 : i8} : (tensor<100x3x10x10xf64>, tensor<1x3x1x1xf64>) -> tensor<100x3x10x10xf64> +// CHECK-NEXT: [[VAR_13_:%.+]] = tosa.mul [[VAR_12_]], %5 {shift = 0 : i8} : (tensor<100x3x10x10xf64>, tensor<1x3x1x1xf64>) -> tensor<100x3x10x10xf64> +// CHECK-NEXT: [[VAR_14_:%.+]] = tosa.add [[VAR_13_]], [[VAR_6_]] : (tensor<100x3x10x10xf64>, tensor<1x3x1x1xf64>) -> tensor<100x3x10x10xf64> +// CHECK-NEXT: return [[VAR_14_]] : tensor<100x3x10x10xf64> } diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir index 1174f38e46..8d8a243223 100644 --- a/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir @@ -240,3 +240,21 @@ func.func @test_resize_cubic_disallowed(%arg0: tensor<1x1x2x4xf32>) -> tensor<1x // CHECK-LABEL: func.func @test_resize_cubic_disallowed // CHECK-LABEL: onnx.Resize } + + +// ----- + +func.func @test_resize_half_pixel_nearest_floor_downsample_axis_one_fp8(%arg0: tensor<1x1x1x12xf8E4M3FN>) -> tensor<1x1x1x6xf8E4M3FN> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.Constant"() {value = dense<[6]> : tensor<1xi64>} : () -> tensor<1xi64> + %2 = "onnx.Resize"(%arg0, %0, %0, %1) {axes = [3], coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, mode = "nearest", nearest_mode = "floor"} : (tensor<1x1x1x12xf8E4M3FN>, none, none, tensor<1xi64>) -> tensor<1x1x1x6xf8E4M3FN> + return %2 : tensor<1x1x1x6xf8E4M3FN> +// CHECK-LABEL: @test_resize_half_pixel_nearest_floor_downsample_axis_one_fp8 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x1x12xf8E4M3FN>) -> tensor<1x1x1x6xf8E4M3FN> { +// CECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CECK-NEXT: [[VAR_1_:%.+]] = tosa.transpose [[PARAM_0_]], [[PARAM_0_]] : (tensor<1x1x1x12xf8E4M3FN>, tensor<4xi32>) -> tensor<1x1x12x1xf8E4M3FN> +// CECK-NEXT: [[VAR_2_:%.+]] = tosa.resize [[VAR_1_]] {border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} : (tensor<1x1x12x1xf8E4M3FN>) -> tensor<1x1x6x1xf8E4M3FN> +// CECK-NEXT: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> +// CECK-NEXT: [[VAR_4_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_3_]] : (tensor<1x1x6x1xf8E4M3FN>, tensor<4xi32>) -> tensor<1x1x1x6xf8E4M3FN> +// CECK-NEXT: return [[VAR_4_]] : tensor<1x1x1x6xf8E4M3FN> + } From e2993ca67e27c7a4a35f4b12bdb4cce98f91a3d3 Mon Sep 17 00:00:00 2001 From: Jose Lopes Date: Wed, 8 May 2024 15:38:20 +0100 Subject: [PATCH 06/17] Add test for Reshape fp8 --- src/Dialect/ONNX/ONNXOps.td.inc | 4 +-- .../conversion/onnx_to_tosa/Flow/fp8.mlir | 36 ------------------- .../onnx_to_tosa/Tensor/Reshape.mlir | 11 ++++++ .../onnx_to_tosa/Tensor/Resize.mlir | 18 ---------- 4 files changed, 13 insertions(+), 56 deletions(-) delete mode 100644 test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir diff --git a/src/Dialect/ONNX/ONNXOps.td.inc b/src/Dialect/ONNX/ONNXOps.td.inc index a3ebf40fba..35512ffe73 100644 --- a/src/Dialect/ONNX/ONNXOps.td.inc +++ b/src/Dialect/ONNX/ONNXOps.td.inc @@ -7241,7 +7241,7 @@ def ONNXResizeOp:ONNX_Op<"Resize", ``` if input \\"sizes\\" is not specified. }]; - let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[AnyFloat]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X, + let arguments = (ins AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$X, AnyTypeOf<[TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, NoneType]>:$roi, AnyTypeOf<[TensorOf<[F32]>, NoneType]>:$scales, AnyTypeOf<[TensorOf<[I64]>, NoneType]>:$sizes, @@ -7254,7 +7254,7 @@ def ONNXResizeOp:ONNX_Op<"Resize", DefaultValuedStrAttr:$keep_aspect_ratio_policy, DefaultValuedStrAttr:$mode, DefaultValuedStrAttr:$nearest_mode); - let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[AnyFloat]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$Y); + let results = (outs AnyTypeOf<[TensorOf<[UI8]>, TensorOf<[UI16]>, TensorOf<[UI32]>, TensorOf<[UI64]>, TensorOf<[I8]>, TensorOf<[I16]>, TensorOf<[I32]>, TensorOf<[I64]>, TensorOf<[BF16]>, TensorOf<[F16]>, TensorOf<[F32]>, TensorOf<[F64]>, TensorOf<[StringType]>, TensorOf<[I1]>, TensorOf<[Complex]>, TensorOf<[Complex]>]>:$Y); let extraClassDeclaration = [{ static int getNumberOfOperands() { return 4; diff --git a/test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir b/test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir deleted file mode 100644 index e4d97b845c..0000000000 --- a/test/mlir/conversion/onnx_to_tosa/Flow/fp8.mlir +++ /dev/null @@ -1,36 +0,0 @@ -// RUN: onnx-mlir-opt --convert-onnx-to-tosa %s -split-input-file | FileCheck %s - -func.func @test_f8E4M3FNUZ(%arg0: tensor<13x21x3xf8E4M3FNUZ>) -> tensor<13x21x3xf8E4M3FNUZ> { - func.return %arg0 : tensor<13x21x3xf8E4M3FNUZ> -} -// CHECK: @test_f8E4M3FNUZ(%[[arg0:.*]]: tensor<13x21x3xf8E4M3FNUZ>) -> tensor<13x21x3xf8E4M3FNUZ> -// CHECK-NEXT: return %[[arg0]] : tensor<13x21x3xf8E4M3FNUZ> - - -func.func @test_f8E4M3FN(%arg0: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> { - func.return %arg0 : tensor<13x21x3xf8E4M3FN> -} -// CHECK: @test_f8E4M3FN(%[[arg0:.*]]: tensor<13x21x3xf8E4M3FN>) -> tensor<13x21x3xf8E4M3FN> -// CHECK-NEXT: return %[[arg0]] : tensor<13x21x3xf8E4M3FN> - -func.func @test_f8E5M2(%arg0: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> { - func.return %arg0 : tensor<13x21x3xf8E5M2> -} - -// CHECK: func.func @test_f8E5M2(%[[arg0:.*]]: tensor<13x21x3xf8E5M2>) -> tensor<13x21x3xf8E5M2> -// CHECK: return %[[arg0]] : tensor<13x21x3xf8E5M2> - -func.func @test_f8E5M2FNUZ(%arg0: tensor<13x21x3xf8E5M2FNUZ>) -> tensor<13x21x3xf8E5M2FNUZ> { - func.return %arg0 : tensor<13x21x3xf8E5M2FNUZ> -} - -// CHECK: @test_f8E5M2FNUZ(%[[arg0:.*]]: tensor<13x21x3xf8E5M2FNUZ>) -> tensor<13x21x3xf8E5M2FNUZ> { -// CHECK-NEXT: return %[[arg0]] : tensor<13x21x3xf8E5M2FNUZ> - -func.func @test_f8E4M3B11FNUZ(%arg0: tensor<13x21x3xf8E4M3B11FNUZ>) -> tensor<13x21x3xf8E4M3B11FNUZ> { - func.return %arg0 : tensor<13x21x3xf8E4M3B11FNUZ> -} - -// CHECK: @test_f8E4M3B11FNUZ(%[[arg0:.*]]: tensor<13x21x3xf8E4M3B11FNUZ>) -> tensor<13x21x3xf8E4M3B11FNUZ> { -// CHECK-NEXT: return %[[arg0]] : tensor<13x21x3xf8E4M3B11FNUZ> - diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir index 73a653be96..9f550cd852 100644 --- a/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir @@ -19,3 +19,14 @@ func.func @test_reshape_allowzero(%arg0 : tensor<12x128x1024xf32>) -> tensor<12x // CHECK: [[VAR_1_:%.+]] = tosa.reshape [[PARAM_0_]] {new_shape = array} : (tensor<12x128x1024xf32>) -> tensor<12x128x16x64xf32> // CHECK-NEXT: return [[VAR_1_]] : tensor<12x128x16x64xf32> } + +func.func @test_reshape_fp8(%arg0 : tensor<128x1024xf8E5M2FNUZ>) -> tensor<1x128x16x64xf8E5M2FNUZ> { + %0 = "onnx.Constant"() {value = dense<[-1, 128, 16, 64]> : tensor<4xi64>} : () -> tensor<4xi64> + %1 = "onnx.Reshape"(%arg0, %0) : (tensor<128x1024xf8E5M2FNUZ>, tensor<4xi64>) -> tensor<1x128x16x64xf8E5M2FNUZ> + "func.return"(%1) : (tensor<1x128x16x64xf8E5M2FNUZ>) -> () +// CHECK-LABEL: @test_reshape_fp8 +// CHECK-SAME: ([[PARAM_0_:%.+]] tensor<128x1024xf8E5M2FNUZ>) -> tensor<1x128x16x64xf8E5M2FNUZ> { +// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[-1, 128, 16, 64]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape %arg0 {new_shape = array} : (tensor<128x1024xf8E5M2FNUZ>) -> tensor<1x128x16x64xf8E5M2FNUZ> +// CHECK-NEXT: return [[VAR_1_]] : tensor<1x128x16x64xf8E5M2FNUZ> + } diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir index 8d8a243223..1174f38e46 100644 --- a/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir @@ -240,21 +240,3 @@ func.func @test_resize_cubic_disallowed(%arg0: tensor<1x1x2x4xf32>) -> tensor<1x // CHECK-LABEL: func.func @test_resize_cubic_disallowed // CHECK-LABEL: onnx.Resize } - - -// ----- - -func.func @test_resize_half_pixel_nearest_floor_downsample_axis_one_fp8(%arg0: tensor<1x1x1x12xf8E4M3FN>) -> tensor<1x1x1x6xf8E4M3FN> { - %0 = "onnx.NoValue"() {value} : () -> none - %1 = "onnx.Constant"() {value = dense<[6]> : tensor<1xi64>} : () -> tensor<1xi64> - %2 = "onnx.Resize"(%arg0, %0, %0, %1) {axes = [3], coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, mode = "nearest", nearest_mode = "floor"} : (tensor<1x1x1x12xf8E4M3FN>, none, none, tensor<1xi64>) -> tensor<1x1x1x6xf8E4M3FN> - return %2 : tensor<1x1x1x6xf8E4M3FN> -// CHECK-LABEL: @test_resize_half_pixel_nearest_floor_downsample_axis_one_fp8 -// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x1x1x12xf8E4M3FN>) -> tensor<1x1x1x6xf8E4M3FN> { -// CECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[0, 2, 3, 1]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CECK-NEXT: [[VAR_1_:%.+]] = tosa.transpose [[PARAM_0_]], [[PARAM_0_]] : (tensor<1x1x1x12xf8E4M3FN>, tensor<4xi32>) -> tensor<1x1x12x1xf8E4M3FN> -// CECK-NEXT: [[VAR_2_:%.+]] = tosa.resize [[VAR_1_]] {border = array, mode = "NEAREST_NEIGHBOR", offset = array, scale = array} : (tensor<1x1x12x1xf8E4M3FN>) -> tensor<1x1x6x1xf8E4M3FN> -// CECK-NEXT: [[VAR_3_:%.+]] = "tosa.const"() <{value = dense<[0, 3, 1, 2]> : tensor<4xi32>}> : () -> tensor<4xi32> -// CECK-NEXT: [[VAR_4_:%.+]] = tosa.transpose [[VAR_2_]], [[VAR_3_]] : (tensor<1x1x6x1xf8E4M3FN>, tensor<4xi32>) -> tensor<1x1x1x6xf8E4M3FN> -// CECK-NEXT: return [[VAR_4_]] : tensor<1x1x1x6xf8E4M3FN> - } From 677575edaf434e5118060cd8d5e0c22f91674e48 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 12:11:58 +0100 Subject: [PATCH 07/17] Change include style --- src/Conversion/ONNXToTOSA/Tensor/Resize.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp index 54341c1be3..620bad5697 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp @@ -13,13 +13,13 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Tosa/IR/TosaOps.h" +#include "mlir/IR/BuiltinTypes.h" #include "src/Conversion/ONNXToTOSA/DialectBuilder.hpp" #include "src/Conversion/ONNXToTOSA/ONNXToTOSACommon.hpp" #include "src/Conversion/ONNXToTOSA/ONNXToTOSALegalizeUtils.hpp" #include -#include #include using namespace mlir; From cd1a71a602de277d846fbf53e33d90d3bc7b5a83 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 12:12:26 +0100 Subject: [PATCH 08/17] Drop check on unused constant --- test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir index 9f550cd852..4501f1ed7e 100644 --- a/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Reshape.mlir @@ -26,7 +26,6 @@ func.func @test_reshape_fp8(%arg0 : tensor<128x1024xf8E5M2FNUZ>) -> tensor<1x128 "func.return"(%1) : (tensor<1x128x16x64xf8E5M2FNUZ>) -> () // CHECK-LABEL: @test_reshape_fp8 // CHECK-SAME: ([[PARAM_0_:%.+]] tensor<128x1024xf8E5M2FNUZ>) -> tensor<1x128x16x64xf8E5M2FNUZ> { -// CHECK-NEXT: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<[-1, 128, 16, 64]> : tensor<4xi64>}> : () -> tensor<4xi64> -// CHECK-NEXT: [[VAR_1_:%.+]] = tosa.reshape %arg0 {new_shape = array} : (tensor<128x1024xf8E5M2FNUZ>) -> tensor<1x128x16x64xf8E5M2FNUZ> +// CHECK: [[VAR_1_:%.+]] = tosa.reshape %arg0 {new_shape = array} : (tensor<128x1024xf8E5M2FNUZ>) -> tensor<1x128x16x64xf8E5M2FNUZ> // CHECK-NEXT: return [[VAR_1_]] : tensor<1x128x16x64xf8E5M2FNUZ> } From 9de956f58d58c40455d763152a346610c5fcb3c3 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 12:48:21 +0100 Subject: [PATCH 09/17] Use isa<>(x) instead x.isa<> (this is deprecated) See https://mlir.llvm.org/deprecation/ --- src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp | 2 +- src/Conversion/ONNXToTOSA/Math/Elementwise.cpp | 6 +++--- src/Conversion/ONNXToTOSA/Tensor/Resize.cpp | 2 +- src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp index c9c21724ec..19396270db 100644 --- a/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp +++ b/src/Conversion/ONNXToTOSA/ConvertONNXToTOSA.cpp @@ -120,7 +120,7 @@ void FrontendToTosaLoweringPass::runOnOperation() { // conversion failures. Quantized types are not supported right now. TypeConverter typeConverter; typeConverter.addConversion([](Type type) -> std::optional { - if (isTOSAInt(type) || type.isa() || type.isa() || + if (isTOSAInt(type) || isa(type) || type.isa() || isTOSABool(type)) return type; return std::nullopt; diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 115e067e96..14464628cb 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -60,7 +60,7 @@ struct ErfIOSupportedTypes { struct IsAnyLegalType { static LogicalResult checkType( ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { - if (!scalarType.isa() && !isTOSAInt(scalarType) && + if (!isa(scalarType) && !isTOSAInt(scalarType) && !isTOSABool(scalarType)) { return rewriter.notifyMatchFailure( op, "this operation only supports signed integer or float types"); @@ -72,7 +72,7 @@ struct IsAnyLegalType { struct IsIntOrFloat { static LogicalResult checkType( ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { - if (!scalarType.isa() && !isTOSAInt(scalarType)) { + if (!isa(scalarType) && !isTOSAInt(scalarType)) { return rewriter.notifyMatchFailure( op, "this operation only supports signed integer or float types"); } @@ -94,7 +94,7 @@ struct IsInt { struct IsFloat { static LogicalResult checkType( ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { - if (!scalarType.isa()) { + if (!isa(scalarType)) { return rewriter.notifyMatchFailure( op, "this operation only supports float types"); } diff --git a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp index 620bad5697..0ddb55f0af 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Resize.cpp @@ -204,7 +204,7 @@ class ONNXResizeOpLoweringToTOSA : public ConversionPattern { } auto elementType = inputType.getElementType(); - if (!(elementType.isa() || isTOSAInt(elementType))) { + if (!(isa(elementType) || isTOSAInt(elementType))) { return rewriter.notifyMatchFailure( resizeOp, "Element type is not supported by TOSA."); } diff --git a/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp b/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp index 0b333b0e2c..100402347d 100644 --- a/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp +++ b/src/Conversion/ONNXToTOSA/Tensor/Transpose.cpp @@ -53,7 +53,7 @@ class ONNXTransposeLoweringToTOSA Type inputElementType = inputType.getElementType(); - if (!inputElementType.isa() && !isTOSAInt(inputElementType) && + if (!isa(inputElementType) && !isTOSAInt(inputElementType) && !inputElementType.isInteger(1)) { return rewriter.notifyMatchFailure( op, "input element type not supported"); From 3a7e8602feebd4aabaa16b900d0386d8ee6c13c5 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 12:52:19 +0100 Subject: [PATCH 10/17] Drop comment --- test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir | 1 - 1 file changed, 1 deletion(-) diff --git a/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir b/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir index f28e8d2a20..0c47f22ad4 100644 --- a/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir +++ b/test/mlir/conversion/onnx_to_tosa/NN/BatchNorm.mlir @@ -85,7 +85,6 @@ func.func @test_batchnorm_bf16_dynamic(%arg0: tensor<100x3x?x?xbf16>) -> tensor< } // ----- -// tosa doesn't support f64, but it is still lowered func.func @test_batchnorm_f64(%arg0: tensor<100x3x10x10xf64>) -> tensor<100x3x10x10xf64> { %0 = "onnx.Constant"() {value = dense<[1.0, 2.0, 3.0]> : tensor<3xf64>} : () -> tensor<3xf64> From 33df7898accbe360a00c959c6e25268dfae3ce95 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 13:45:49 +0100 Subject: [PATCH 11/17] Add test for resize with f64 --- test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir index 1174f38e46..af185ced75 100644 --- a/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Resize.mlir @@ -102,6 +102,18 @@ func.func @test_resize_half_pixel_nearest_floor_downsample(%arg0: tensor<1x1x1x1 // ----- +func.func @test_resize_f64(%arg0: tensor<1x1x1x4xf64>) -> tensor<1x1x1x12xf64> { + %0 = "onnx.NoValue"() {value} : () -> none + %1 = "onnx.Constant"() {value = dense<[1, 1, 1, 12]> : tensor<4xi64>} : () -> tensor<4xi64> + %2 = "onnx.Resize"(%arg0, %0, %0, %1) {coordinate_transformation_mode = "half_pixel", cubic_coeff_a = -7.500000e-01 : f32, exclude_outside = 0 : si64, extrapolation_value = 0.000000e+00 : f32, mode = "nearest", nearest_mode = "floor"} : (tensor<1x1x1x4xf64>, none, none, tensor<4xi64>) -> tensor<1x1x1x12xf64> + return %2 : tensor<1x1x1x12xf64> +// CHECK-LABEL: func.func @test_resize_f64 +// CHECK-NOT: onnx.Resize +// CHECK: return {{.*}}: tensor<1x1x1x12xf64> +} + +// ----- + func.func @test_resize_input_one(%arg0: tensor<1x1x1x1xf32>) -> tensor<1x1x4x4xf32> { %0 = "onnx.NoValue"() {value} : () -> none %1 = "onnx.Constant"() {value = dense<[1, 1, 4, 4]> : tensor<4xi64>} : () -> tensor<4xi64> From 2fd5a775008898023d4ebd0dd5f408e59a854b87 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 13:55:56 +0100 Subject: [PATCH 12/17] Add f64 onnx to tosa transpose test Float8 types start to be supported from V21. --- test/mlir/conversion/onnx_to_tosa/Tensor/Transpose.mlir | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/mlir/conversion/onnx_to_tosa/Tensor/Transpose.mlir b/test/mlir/conversion/onnx_to_tosa/Tensor/Transpose.mlir index 4b3bb6f80f..a95eeebf2c 100644 --- a/test/mlir/conversion/onnx_to_tosa/Tensor/Transpose.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Tensor/Transpose.mlir @@ -19,3 +19,11 @@ func.func @test_transpose(%arg0 : tensor<5x5x1x32xf32>) -> tensor<5x1x32x5xf32> // CHECK: %[[VAL_2:.*]] = tosa.transpose %[[VAL_0]], %[[VAL_1]] : (tensor<5x5x1x32xf32>, tensor<4xi32>) -> tensor<5x1x32x5xf32> // CHECK: return %[[VAL_2]] : tensor<5x1x32x5xf32> } + +func.func @test_transpose_f64(%arg0 : tensor<5x5x1x32xf64>) -> tensor<5x1x32x5xf64> { + %0 = "onnx.Transpose"(%arg0) {perm = [0, 2, 3, 1]} : (tensor<5x5x1x32xf64>) -> tensor<5x1x32x5xf64> + return %0 : tensor<5x1x32x5xf64> +// CHECK-LABEL: func.func @test_transpose +// CHECK-NOT: onnx.Transpose +// CHECK: return {{.*}}: tensor<5x1x32x5xf64> +} From b5a0e48dc9eb75113cb8b47ad4f5190e84ddbe85 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 14:20:29 +0100 Subject: [PATCH 13/17] Drop unused function --- src/Conversion/ONNXToTOSA/Math/Elementwise.cpp | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 14464628cb..39b7b3c230 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -57,18 +57,6 @@ struct ErfIOSupportedTypes { } }; -struct IsAnyLegalType { - static LogicalResult checkType( - ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { - if (!isa(scalarType) && !isTOSAInt(scalarType) && - !isTOSABool(scalarType)) { - return rewriter.notifyMatchFailure( - op, "this operation only supports signed integer or float types"); - } - return success(); - } -}; - struct IsIntOrFloat { static LogicalResult checkType( ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { From ddcfb491d80f8107dea9263f43d68611e319d534 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 14:20:47 +0100 Subject: [PATCH 14/17] Add tests for changed type check functions --- .../onnx_to_tosa/Math/Elementwise.mlir | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index 2b0d1a6532..d626092c95 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -103,6 +103,15 @@ func.func @test_add_ui32(%arg0: tensor<13x21x1xui32>, %arg1: tensor<13x21x1xui32 // CHECK: return [[VAR_0_]] : tensor<13x21x1xui32> } +// ----- + +func.func @test_add_f64(%arg0: tensor<13x21x1xf64>, %arg1: tensor<13x21x1xf64>) -> tensor<13x21x1xf64> { + %0 = "onnx.Add"(%arg0, %arg1) : (tensor<13x21x1xf64>, tensor<13x21x1xf64>) -> tensor<13x21x1xf64> + "func.return"(%0) : (tensor<13x21x1xf64>) -> () +// CHECK-LABEL: func.func @test_add_f64 +// CHECK-NOT: onnx.Add +// CHECK: return {{.*}}: tensor<13x21x1xf64> +} // ----- @@ -484,6 +493,14 @@ func.func @test_pow_broadcast(%arg0: tensor<13x21x1xf32>, %arg1: tensor<1xf32>) // CHECK-NEXT: [[VAR_1_:%.+]] = tosa.pow [[PARAM_0_]], [[VAR_0_]] : (tensor<13x21x1xf32>, tensor<1x1x1xf32>) -> tensor<13x21x1xf32> } +func.func @test_pow_f64(%arg0: tensor<13x21x1xf64>, %arg1: tensor<13x21x1xf64>) -> tensor<13x21x1xf64> { + %0 = "onnx.Pow"(%arg0, %arg1) : (tensor<13x21x1xf64>, tensor<13x21x1xf64>) -> tensor<13x21x1xf64> + "func.return"(%0) : (tensor<13x21x1xf64>) -> () +// CHECK-LABEL: func @test_pow +// CHECK-NOT: onnx.Pow +// CHECK: return {{.*}}: tensor<13x21x1xf64> +} + // ----- func.func @test_sqrt(%arg0: tensor<3xf32>) -> tensor<3xf32> { From c8919091e0444132a739b70416bcb2b69c255f6d Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 14:28:45 +0100 Subject: [PATCH 15/17] Drop special restrictions on ERF --- src/Conversion/ONNXToTOSA/Math/Elementwise.cpp | 14 +------------- .../conversion/onnx_to_tosa/Math/Elementwise.mlir | 8 ++++++++ 2 files changed, 9 insertions(+), 13 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 39b7b3c230..8500a7ba1a 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -45,18 +45,6 @@ struct AbsIOSupportedTypes { } }; -struct ErfIOSupportedTypes { - static LogicalResult checkType( - ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { - if (!mlir::isa( - scalarType)) { - return rewriter.notifyMatchFailure( - op, "this operation only supports fp16, fp32 or bf16"); - } - return success(); - } -}; - struct IsIntOrFloat { static LogicalResult checkType( ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { @@ -747,7 +735,7 @@ static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern( ONNXElementwiseUnaryOpLoweringToTOSA, ONNXElementwiseUnaryOpLoweringToTOSA>(typeConverter, ctx); + IsFloat, IsFloat>>(typeConverter, ctx); // Tosa custom ops #define INSERT_ONNX_UNARY_TO_TOSA_CUSTOMOP_PATTERN( \ diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index d626092c95..9dcfc12253 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -558,6 +558,14 @@ func.func @test_erf_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> { // CHECK-NEXT: } } +func.func @test_erf_f64(%arg0: tensor<3xf64>) -> tensor<3xf64> { + %0 = "onnx.Erf"(%arg0) : (tensor<3xf64>) -> tensor<3xf64> + return %0 : tensor<3xf64> +// CHECK-LABEL: func @test_erf_f64 +// CHECK-NOT: onnx.Erf +// CHECK: return %0 : tensor<3xf64> +} + // ----- func.func @test_bitwise_not(%arg0 : tensor<10x10xi32>) -> tensor<10x10xi32> { From 2aa17eb86ab53020eae975d5e30f7d689d569db0 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 14:33:27 +0100 Subject: [PATCH 16/17] Drop special restrictions on ABS --- src/Conversion/ONNXToTOSA/Math/Elementwise.cpp | 16 +--------------- .../onnx_to_tosa/Math/Elementwise.mlir | 8 ++++++++ 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp index 8500a7ba1a..4d6368a73e 100644 --- a/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp +++ b/src/Conversion/ONNXToTOSA/Math/Elementwise.cpp @@ -31,20 +31,6 @@ struct TOSADialectOp { using Op = mlir::tosa::NegateOp; }; -struct AbsIOSupportedTypes { - static LogicalResult checkType( - ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { - if (!mlir::isa( - scalarType) && - !scalarType.isSignlessInteger(/*width=*/32)) { - return rewriter.notifyMatchFailure(op, - "this operation only supports signless 32 integer or fp16, fp32" - " or bf16"); - } - return success(); - } -}; - struct IsIntOrFloat { static LogicalResult checkType( ConversionPatternRewriter &rewriter, Type scalarType, Operation *op) { @@ -733,7 +719,7 @@ static void populateLoweringONNXElementwiseUnaryTemplateOpToTOSAPattern( ONNXElementwiseUnaryOpLoweringToTOSA, ONNXElementwiseUnaryOpLoweringToTOSA, + IsIntOrFloat, IsIntOrFloat>, ONNXElementwiseUnaryOpLoweringToTOSA>(typeConverter, ctx); diff --git a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir index 9dcfc12253..299f6f07e1 100644 --- a/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir +++ b/test/mlir/conversion/onnx_to_tosa/Math/Elementwise.mlir @@ -536,6 +536,14 @@ func.func @test_abs_bf16(%arg0: tensor<3xbf16>) -> tensor<3xbf16> { // CHECK-NEXT: } } +func.func @test_abs_f64(%arg0: tensor<3xf64>) -> tensor<3xf64> { + %0 = "onnx.Abs"(%arg0) : (tensor<3xf64>) -> tensor<3xf64> + return %0 : tensor<3xf64> +// CHECK-LABEL: func @test_abs_f64 +// CHECK-NOT: onnx.Abs +// CHECK: return {{.*}}: tensor<3xf64> +} + // ----- func.func @test_erf_f32(%arg0: tensor<3xf32>) -> tensor<3xf32> { From 265313ba45ee7b7ac5fe7fdae0c39add1ee1e9f1 Mon Sep 17 00:00:00 2001 From: Tina Jung Date: Fri, 10 May 2024 14:56:26 +0100 Subject: [PATCH 17/17] Update mlir version to include more tosa floats --- utils/clone-mlir.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/clone-mlir.sh b/utils/clone-mlir.sh index c42835959a..2cadbf8cdb 100644 --- a/utils/clone-mlir.sh +++ b/utils/clone-mlir.sh @@ -1,3 +1,3 @@ git clone -n https://github.com/xilinx/llvm-project.git # Check out a specific branch that is known to work with ONNX-MLIR. -cd llvm-project && git checkout af893daa8d19 && cd .. +cd llvm-project && git checkout fda272652fd65e139ed162a9c7ce521133eb34a0 && cd ..