From 20ae283d087224f6b82b7308054bd34a6764d926 Mon Sep 17 00:00:00 2001 From: Tai Ly Date: Tue, 11 Feb 2025 13:02:44 -0600 Subject: [PATCH] [mlir][tosa] Change the shift of mul to be required (#125297) Change the shift operand for the mul operator to be a required operand. Also defined shift to be Tosa_ScalarInt8Tensor which requires that it is a rank-1 tensor whose shape is [1] (ie, tensor containing a single element) Signed-off-by: Tai Ly --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 21 +++++----- .../mlir/Dialect/Tosa/IR/TosaTypesBase.td | 20 +++++---- .../Conversion/TosaToLinalg/TosaToLinalg.cpp | 21 ++++++---- mlir/lib/Dialect/Tosa/IR/TosaOps.cpp | 21 ++++++---- .../Tosa/Transforms/TosaReduceTransposes.cpp | 6 +-- .../TosaToLinalg/tosa-to-linalg.mlir | 6 ++- mlir/test/Dialect/Tosa/canonicalize.mlir | 14 ++++--- mlir/test/Dialect/Tosa/constant-op-fold.mlir | 15 ++++--- mlir/test/Dialect/Tosa/invalid.mlir | 42 +++++++++++++++---- mlir/test/Dialect/Tosa/ops.mlir | 3 +- mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir | 34 ++++++++------- .../Dialect/Tosa/tosa-reduce-transposes.mlir | 10 +++-- 12 files changed, 135 insertions(+), 78 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 7a65417db0eab..b8755da8db32e 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -105,8 +105,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Optional:$input_zp, - Optional:$weight_zp, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, @@ -136,8 +136,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> { Tosa_Tensor5D:$input, TosaTensorRankOf<[Tosa_Weight], [5]>:$weight, Tosa_Tensor1D:$bias, - Optional:$input_zp, - Optional:$weight_zp, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr6:$pad, Tosa_IntArrayAttr3:$stride, Tosa_IntArrayAttr3:$dilation, @@ -168,8 +168,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Optional:$input_zp, - Optional:$weight_zp, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, @@ -356,8 +356,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Optional:$input_zp, - Optional:$weight_zp, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$out_pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr4:$out_shape, @@ -817,7 +817,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [ let arguments = (ins Tosa_Tensor:$input1, Tosa_Tensor:$input2, - Optional>:$shift + // Apply right shift on i32_t input data only + Tosa_ScalarInt8Tensor:$shift ); let results = (outs @@ -1590,7 +1591,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> { let arguments = (ins Tosa_RankedTensor:$input1, Tosa_Shape:$padding, - Optional:$pad_const, + Optional:$pad_const, OptionalAttr:$input_zp ); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 120adf82249e0..6457bb8749ee0 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[ IsRankedTensorTypePred, CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>; +def AllDimensionsAreSizeOne : And<[ + IsRankedTensorTypePred, + CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>; + class TosaTensorOf< list allowedTypes, string summary = "tosa-conformant tensor"> : TensorOf], summary>; @@ -109,6 +113,11 @@ class TosaTensorRankOf allowedTypes, list ranks> [HasAnyRankOfPred], !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">; +class TosaScalarTensorOf allowedTypes, list ranks> + : TosaRankedTensorOf, AllDimensionsAreSizeOne], + "tosa-conformant scalar tensor">; + //===----------------------------------------------------------------------===// // Tensor types //===----------------------------------------------------------------------===// @@ -136,8 +145,10 @@ class Tosa_TensorOfOrNone allowedTypes, string description = ""> : // Tensor types with constrained ranks. //===----------------------------------------------------------------------===// -// Rank-0 (scalar) tensor -def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; +def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; + +def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>; +def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>; // We include unranked tensors as a supported type for all possible tosa // Tensors as unranked does not guarantee invalid. If unranked tensors exist @@ -296,9 +307,4 @@ def Rank1TosaShape : TosaShapeOfRank<1>; def Rank2TosaShape : TosaShapeOfRank<2>; def Rank4TosaShape : TosaShapeOfRank<4>; -// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this -// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the -// following def can be removed. -def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>; - #endif // TOSA_TYPES_BASE diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 28bc8732b7978..d849c782bf08b 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -92,22 +92,27 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::MulOp if (isa(op)) { auto shift_val = cast(op).getShift(); + ElementsAttr shift_elem; + if (!shift_val.getImpl() || + !matchPattern(shift_val, m_Constant(&shift_elem))) { + (void)rewriter.notifyMatchFailure(op, "shift value of mul not found"); + } + + int32_t shift = shift_elem.getValues()[0].getInt(); if (isa(elementTy)) { + if (shift != 0) { + (void)rewriter.notifyMatchFailure(op, + "Cannot have shift value for float"); + return nullptr; + } return rewriter.create(loc, resultTypes, args[0], args[1]); } if (isa(elementTy)) { - int32_t shift = 0; - ElementsAttr shift_elem; - if (shift_val.getImpl() && - matchPattern(shift_val, m_Constant(&shift_elem))) { - // Explicit shift is set. - shift = shift_elem.getValues()[0].getInt(); - } - Value a = args[0]; Value b = args[1]; + if (shift > 0) { auto shiftConst = rewriter.create(loc, shift, /*bitwidth=*/8); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index cff24d825d3f5..4928be38476a9 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1130,16 +1130,10 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents( ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - LogicalResult status = success(); + // mul op's output shape only depend on input1 and input2, not on shift + ValueShapeRange twoInputs = operands.drop_back(); llvm::SmallVector outShape; - if (operands.size() == 2) { - status = resolveBroadcastShape(operands, outShape); - } else { - // mul op's output shape only depend on input1 and input2, not on shift - ValueShapeRange two_inputs = operands.drop_back(); - status = resolveBroadcastShape(two_inputs, outShape); - } - if (status.failed()) { + if (resolveBroadcastShape(twoInputs, outShape).failed()) { inferredReturnShapes.push_back(ShapedTypeComponents()); } else { inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); @@ -1174,6 +1168,15 @@ LogicalResult tosa::MulOp::verify() { return emitOpError( "requires the same element type for all operands and results"); } + + // verify shift has value 0 for non-integer types + ElementsAttr shift_elem; + if (matchPattern(getShift(), m_Constant(&shift_elem))) { + int32_t shift = shift_elem.getValues()[0].getInt(); + if (shift != 0) { + return emitOpError() << "require shift to be 0 for float type"; + } + } } // Verify the op has same ranks for all main operands (excludes extra operands diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp index 281f0529a5c08..64e5c31793f84 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp @@ -287,8 +287,7 @@ bool TosaReduceTransposes::collectFanIn(Operation *op, for (Value operand : op->getOperands()) { // If this is a problem in future, think about alternatives to recursion. - if (llvm::isa(op) && op->getNumOperands() == 3 && - operand == op->getOperand(2)) { + if (llvm::isa(op) && operand == op->getOperand(2)) { // do not recurse into MulOp's shift operand continue; } @@ -332,8 +331,7 @@ std::optional TosaReduceTransposes::buildMappedToValue( for (Value v : op->getOperands()) { if (valuesMap.contains(v)) { operands.push_back(valuesMap.at(v)); - } else if (llvm::isa(op) && op->getNumOperands() == 3 && - v == op->getOperand(2)) { + } else if (llvm::isa(op) && v == op->getOperand(2)) { // special case for MulOp's shift operand operands.push_back(v); } else { diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 56521fb67ef0c..17add2d41afe7 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -472,7 +472,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () { // CHECK: linalg.generic // CHECK: arith.mulf - %4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %4 = tosa.mul %0, %1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: arith.negf @@ -618,7 +619,8 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () { // CHECK: arith.extsi // CHECK: arith.extsi // CHECK: arith.muli - %0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.mul %arg0, %arg0, %shift : (tensor<1xi16>, tensor<1xi16>, tensor<1xi8>) -> tensor<1xi32> return } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 58a70fb03a092..24d572244a9b0 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -322,8 +322,9 @@ func.func @pad_determine_val_quant(%arg0: tensor, %arg1 : tensor<2x2xi3 func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: return %arg0 // CHECK-NOT: tosa.mul + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> %ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32> return %1 : tensor<2x3xf32> } @@ -334,7 +335,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: return %arg0 // CHECK-NOT: tosa.mul %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> - %1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32> return %1 : tensor<2x3xf32> } @@ -370,11 +372,12 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>} // CHECK-NOT: tosa.mul %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> - %1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32> // CHECK-NOT: tosa.mul // CHECK: return %[[ZERO]], %[[ZERO]] - %2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32> return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32> } @@ -974,7 +977,8 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform : tensor<1xi8>} : () -> tensor<1x!quant.uniform> %1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform> - %2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>)-> tensor<1x!quant.uniform> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %2 = tosa.mul %0, %1, %shift : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1xi8>) -> tensor<1x!quant.uniform> return %2 : tensor<1x!quant.uniform> } diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 40469987d89d0..e6fb741df9598 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -238,7 +238,8 @@ func.func @fold_div_splat_i32() -> tensor { func.func @fold_mul_zero_rhs_f32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> - %mul = tosa.mul %arg0, %zero : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %[[ZERO]] return %mul : tensor } @@ -249,7 +250,8 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor) -> tensor { func.func @fold_mul_zero_lhs_f32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> - %mul = tosa.mul %zero, %arg0 : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %[[ZERO]] return %mul : tensor } @@ -283,7 +285,8 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_mul_one_rhs_f32 func.func @fold_mul_one_rhs_f32(%arg0: tensor) -> tensor { %one = "tosa.const"() {value = dense<1.0> : tensor} : () -> tensor - %mul = tosa.mul %arg0, %one : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %arg0 return %mul : tensor } @@ -293,7 +296,8 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_mul_one_lhs_f32 func.func @fold_mul_one_lhs_f32(%arg0: tensor) -> tensor { %one = "tosa.const"() {value = dense<1.0> : tensor} : () -> tensor - %mul = tosa.mul %one, %arg0 : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %arg0 return %mul : tensor } @@ -339,7 +343,8 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> { func.func @fold_mul_splat_f32() -> tensor<10xf32> { %one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32> %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %mul = tosa.mul %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32> // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>} // CHECK: return %[[THREE]] return %mul : tensor<10xf32> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index e77078161d063..913191be86f85 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -768,26 +768,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, // CHECK-LABEL: test_mul_type_mismatch func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}} - %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf16>) -> tensor<13x21x3xf32> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } // ----- // CHECK-LABEL: test_mul_invalid_shift -func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { - %shift = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor - // expected-error@+1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor'}} - %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor) -> tensor<13x21x3xi32> - return %0 : tensor<13x21x3xi32> +func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> } // ----- // CHECK-LABEL: test_mul_missing_shift func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { - // this is ok because mul's shift operand is optional for now + // expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}} %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> return %0 : tensor<13x21x3xi32> } @@ -1099,3 +1100,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1: %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> return %0 : tensor<1x13x21x3xf32> } + +// ----- +// CHECK-LABEL: test_mul_non_scalar_shift_2d +func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> + // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: test_mul_non_scalar_shift_1d +func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{value = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8> + // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: test_mul_non_broadcast +func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 0d3dfead4a7a3..348849cfaa572 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -355,7 +355,8 @@ func.func @test_mul_scalar_with_unranked_output(%arg0: tensor, %arg1: tenso // ----- // CHECK-LABEL: mul func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index a0a7e5dec6ed0..7dc9b048085fa 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -114,23 +114,24 @@ func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %3 = tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %3 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<4xf32> + %4 = tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<*xf32> // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %5 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %6 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %7 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %8 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %9 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> return } @@ -148,23 +149,24 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32 // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %3 = tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %3 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<4xf32> + %4 = tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<*xf32> // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %5 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %6 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %7 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %7 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %8 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %9 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> return } @@ -211,10 +213,10 @@ func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<1xi32>) -> () { %11 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<*xi32> // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + %13 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %13 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + %14 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> return } diff --git a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir index e70f3644da646..3a293009a5455 100644 --- a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir @@ -193,6 +193,7 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x // CHECK-LABEL: @test_resnet18_common_case // COM: note that %74 is now represented by %arg2 +// CHECK-DAG: %[[CONST0:.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> // CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> // CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> tensor<1xf32> @@ -205,15 +206,16 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x // CHECK-DAG: %[[VAL_12:.*]] = tosa.sub %arg2, %[[VAL_11]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> // CHECK-DAG: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_14:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_13]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> -// CHECK-DAG: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> +// CHECK-DAG: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]], %[[CONST0]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>, tensor<1xi8>) -> tensor<1x112x112x64xf32> // CHECK-DAG: %[[VAL_16:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_17:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_16]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> -// CHECK-DAG: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> +// CHECK-DAG: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]], %[[CONST0]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>, tensor<1xi8>) -> tensor<1x112x112x64xf32> // CHECK-DAG: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_20:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_19]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> // CHECK-DAG: %[[VAL_21:.*]] = tosa.add %[[VAL_18]], %[[VAL_20]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> // CHECK-DAG: %[[VAL_22:.*]] = tosa.clamp %[[VAL_21]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> { + %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> %58 = tosa.const_shape {value = dense<[1, 64, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> %59 = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> %60 = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> @@ -228,9 +230,9 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32 %79 = tosa.reshape %arg0, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> %80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> %81 = tosa.reshape %78, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> - %82 = tosa.mul %80, %81 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> + %82 = tosa.mul %80, %81, %shift : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>, tensor<1xi8>) -> tensor<1x64x112x112xf32> %83 = tosa.reshape %60, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> - %84 = tosa.mul %82, %83 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> + %84 = tosa.mul %82, %83, %shift : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>, tensor<1xi8>) -> tensor<1x64x112x112xf32> %85 = tosa.reshape %59, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> %86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> %87 = tosa.clamp %86 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>