diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 21451c83a2cc1..9efc9badd5053 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -105,8 +105,8 @@ def Tosa_Conv2DOp : Tosa_ConvOp<"conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Optional:$input_zp, - Optional:$weight_zp, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, @@ -136,8 +136,8 @@ def Tosa_Conv3DOp : Tosa_ConvOp<"conv3d"> { Tosa_Tensor5D:$input, TosaTensorRankOf<[Tosa_Weight], [5]>:$weight, Tosa_Tensor1D:$bias, - Optional:$input_zp, - Optional:$weight_zp, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr6:$pad, Tosa_IntArrayAttr3:$stride, Tosa_IntArrayAttr3:$dilation, @@ -168,8 +168,8 @@ def Tosa_DepthwiseConv2DOp : Tosa_ConvOp<"depthwise_conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Optional:$input_zp, - Optional:$weight_zp, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr2:$dilation, @@ -356,8 +356,8 @@ def Tosa_TransposeConv2DOp : Tosa_ConvOp<"transpose_conv2d"> { Tosa_Tensor4D:$input, TosaTensorRankOf<[Tosa_Weight], [4]>:$weight, Tosa_Tensor1D:$bias, - Optional:$input_zp, - Optional:$weight_zp, + Optional:$input_zp, + Optional:$weight_zp, Tosa_IntArrayAttr4:$out_pad, Tosa_IntArrayAttr2:$stride, Tosa_IntArrayAttr4:$out_shape, @@ -821,7 +821,8 @@ def Tosa_MulOp : Tosa_Op<"mul", [ let arguments = (ins Tosa_Tensor:$input1, Tosa_Tensor:$input2, - Optional>:$shift + // Apply right shift on i32_t input data only + Tosa_ScalarInt8Tensor:$shift ); let results = (outs @@ -1614,7 +1615,7 @@ def Tosa_PadOp : Tosa_InferShapedTypeOp<"pad"> { let arguments = (ins Tosa_RankedTensor:$input1, Tosa_Shape:$padding, - Optional:$pad_const, + 
Optional:$pad_const, OptionalAttr:$input_zp ); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 28b2f90c90052..5685baae724c7 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -93,6 +93,10 @@ def HasNo0Dimensions : And<[ IsRankedTensorTypePred, CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v != 0; })">]>; +def AllDimensionsAreSizeOne : And<[ + IsRankedTensorTypePred, + CPred<"::llvm::all_of(::llvm::cast<::mlir::RankedTensorType>($_self).getShape(), [](auto v) { return v == 1; })">]>; + // AMD: removed HasNo0Dimensions constraint below to allow lowerings // in onnx-mlir like onnx.Split. class TosaTensorOf< @@ -111,6 +115,11 @@ class TosaTensorRankOf allowedTypes, list ranks> [HasAnyRankOfPred], !interleave(!foreach(rank, ranks, rank # "D"), "/") # " tensor">; +class TosaScalarTensorOf allowedTypes, list ranks> + : TosaRankedTensorOf, AllDimensionsAreSizeOne], + "tosa-conformant scalar tensor">; + //===----------------------------------------------------------------------===// // Tensor types //===----------------------------------------------------------------------===// @@ -139,8 +148,10 @@ class Tosa_TensorOfOrNone allowedTypes, string description = ""> : // Tensor types with constrained ranks. //===----------------------------------------------------------------------===// -// Rank-0 (scalar) tensor -def Tosa_ScalarTensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; +def Tosa_Rank0Tensor : TosaTensorRankOf<[Tosa_AnyNumber], [0]>; + +def Tosa_ScalarTensor : TosaScalarTensorOf<[Tosa_AnyNumber], [1]>; +def Tosa_ScalarInt8Tensor : TosaScalarTensorOf<[Tosa_Int8], [1]>; // We include unranked tensors as a supported type for all possible tosa // Tensors as unranked does not guarantee invalid. 
If unranked tensors exist @@ -299,9 +310,4 @@ def Rank1TosaShape : TosaShapeOfRank<1>; def Rank2TosaShape : TosaShapeOfRank<2>; def Rank4TosaShape : TosaShapeOfRank<4>; -// NOTE: Tosa_ScalarTensor is currently defined as rank-0. If and when this -// becomes rank-1 it can be used in place of Tosa_ZeroPointTensor and the -// following def can be removed. -def Tosa_ZeroPointTensor : TosaTensorRankOf<[Tosa_AnyNumber], [1]>; - #endif // TOSA_TYPES_BASE diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 2c291fc12430c..b84b2fddd047d 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -102,22 +102,29 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::MulOp if (isa(op)) { auto shift_val = cast(op).getShift(); + ElementsAttr shift_elem; + if (!shift_val.getImpl() || + !matchPattern(shift_val, m_Constant(&shift_elem))) { + (void)rewriter.notifyMatchFailure(op, "shift value of mul not found"); + // shift_elem is still null here; bail out before it is dereferenced. + return nullptr; + } + + int32_t shift = shift_elem.getValues()[0].getInt(); if (isa(elementTy)) { + if (shift != 0) { + (void)rewriter.notifyMatchFailure(op, + "Cannot have shift value for float"); + return nullptr; + } return rewriter.create(loc, resultTypes, args[0], args[1]); } if (isa(elementTy)) { - int32_t shift = 0; - ElementsAttr shift_elem; - if (shift_val.getImpl() && - matchPattern(shift_val, m_Constant(&shift_elem))) { - // Explicit shift is set. 
- shift = shift_elem.getValues()[0].getInt(); - } - Value a = args[0]; Value b = args[1]; + if (shift > 0) { auto shiftConst = rewriter.create(loc, shift, /*bitwidth=*/8); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index f1fcc17cf360f..3f348ba3acd6e 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -1204,16 +1204,10 @@ LogicalResult tosa::MulOp::inferReturnTypeComponents( ValueShapeRange operands, DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions, SmallVectorImpl &inferredReturnShapes) { - LogicalResult status = success(); + // mul op's output shape only depend on input1 and input2, not on shift + ValueShapeRange twoInputs = operands.drop_back(); llvm::SmallVector outShape; - if (operands.size() == 2) { - status = resolveBroadcastShape(operands, outShape); - } else { - // mul op's output shape only depend on input1 and input2, not on shift - ValueShapeRange two_inputs = operands.drop_back(); - status = resolveBroadcastShape(two_inputs, outShape); - } - if (status.failed()) { + if (resolveBroadcastShape(twoInputs, outShape).failed()) { inferredReturnShapes.push_back(ShapedTypeComponents()); } else { inferredReturnShapes.push_back(ShapedTypeComponents(outShape)); @@ -1248,6 +1242,15 @@ LogicalResult tosa::MulOp::verify() { return emitOpError( "requires the same element type for all operands and results"); } + + // verify shift has value 0 for non-integer types + ElementsAttr shift_elem; + if (matchPattern(getShift(), m_Constant(&shift_elem))) { + int32_t shift = shift_elem.getValues()[0].getInt(); + if (shift != 0) { + return emitOpError() << "require shift to be 0 for float type"; + } + } } // Verify the op has same ranks for all main operands (excludes extra operands diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp index 281f0529a5c08..64e5c31793f84 100644 --- 
a/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaReduceTransposes.cpp @@ -287,8 +287,7 @@ bool TosaReduceTransposes::collectFanIn(Operation *op, for (Value operand : op->getOperands()) { // If this is a problem in future, think about alternatives to recursion. - if (llvm::isa(op) && op->getNumOperands() == 3 && - operand == op->getOperand(2)) { + if (llvm::isa(op) && operand == op->getOperand(2)) { // do not recurse into MulOp's shift operand continue; } @@ -332,8 +331,7 @@ std::optional TosaReduceTransposes::buildMappedToValue( for (Value v : op->getOperands()) { if (valuesMap.contains(v)) { operands.push_back(valuesMap.at(v)); - } else if (llvm::isa(op) && op->getNumOperands() == 3 && - v == op->getOperand(2)) { + } else if (llvm::isa(op) && v == op->getOperand(2)) { // special case for MulOp's shift operand operands.push_back(v); } else { diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index a0df696c53b2d..2effc2194138b 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -492,7 +492,8 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () { // CHECK: linalg.generic // CHECK: arith.mulf - %4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %4 = tosa.mul %0, %1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: arith.negf @@ -658,7 +659,8 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () { // CHECK: arith.extsi // CHECK: arith.extsi // CHECK: arith.muli - %0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.mul %arg0, %arg0, %shift : (tensor<1xi16>, tensor<1xi16>, 
tensor<1xi8>) -> tensor<1xi32> return } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index c4e655d6d25ce..c160fca2cbc1f 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -417,8 +417,9 @@ func.func @pad_determine_val_quant(%arg0: tensor, %arg1 : tensor<2x2xi3 func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: return %arg0 // CHECK-NOT: tosa.mul + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> %ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32> - %1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32> return %1 : tensor<2x3xf32> } @@ -429,7 +430,8 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: return %arg0 // CHECK-NOT: tosa.mul %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> - %1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = tosa.mul %ones, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32> return %1 : tensor<2x3xf32> } @@ -465,11 +467,12 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>} // CHECK-NOT: tosa.mul %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> - %1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = tosa.mul %arg0, %zeros, %shift : (tensor<2x3xf32>, tensor<1x1xf32>, tensor<1xi8>) -> tensor<2x3xf32> // CHECK-NOT: 
tosa.mul // CHECK: return %[[ZERO]], %[[ZERO]] - %2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor<2x3xf32>, tensor<1xi8>) -> tensor<2x3xf32> return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32> } @@ -478,8 +481,9 @@ func.func @mul_zero_broadcast_dynamic_result(%arg0: tensor) -> (tensor< // CHECK: tosa.mul // CHECK: tosa.mul %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32> - %1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor, tensor<1x1xf32>) -> tensor - %2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = tosa.mul %arg0, %zeros, %shift : (tensor, tensor<1x1xf32>, tensor<1xi8>) -> tensor + %2 = tosa.mul %zeros, %arg0, %shift : (tensor<1x1xf32>, tensor, tensor<1xi8>) -> tensor return %1, %2 : tensor, tensor } @@ -1438,7 +1442,8 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform : tensor<1xi8>} : () -> tensor<1x!quant.uniform> %1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform> - %2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>)-> tensor<1x!quant.uniform> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %2 = tosa.mul %0, %1, %shift : (tensor<1x!quant.uniform>, tensor<1x!quant.uniform>, tensor<1xi8>) -> tensor<1x!quant.uniform> return %2 : tensor<1x!quant.uniform> } @@ -1573,7 +1578,8 @@ func.func @canonicalize_select_lrelu_zero_pattern(%arg0: tensor<13x21x3xf32>) -> // CHECK: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 0x7F800000 : f32, min_val = 0.000000e+00 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: return [[VAR_0_]] : tensor<13x21x3xf32> %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}>: () -> tensor<1x1x1xf32> - %1 = tosa.mul %arg0, %0 {shift = 0 : i8}: 
(tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %1 = tosa.mul %arg0, %0, %shift: (tensor<13x21x3xf32>, tensor<1x1x1xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xi1> %3 = tosa.select %2, %arg0, %1: ( tensor<13x21x3xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> return %3 : tensor<13x21x3xf32> diff --git a/mlir/test/Dialect/Tosa/constant-mul-opt.mlir b/mlir/test/Dialect/Tosa/constant-mul-opt.mlir index c00f01e152b36..dbd757d4768ec 100644 --- a/mlir/test/Dialect/Tosa/constant-mul-opt.mlir +++ b/mlir/test/Dialect/Tosa/constant-mul-opt.mlir @@ -15,7 +15,8 @@ func.func @mul_fold_float() -> tensor<4xf16> { dense<[-132.7, -3.0, -0.0, 5.0]> : tensor<4xf16> } : () -> tensor<4xf16> - %2 = "tosa.mul"(%0, %1) : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %2 = "tosa.mul"(%0, %1, %shift) : (tensor<4xf16>, tensor<4xf16>, tensor<1xi8>) -> tensor<4xf16> return %2 : tensor<4xf16> } @@ -32,7 +33,8 @@ func.func @mul_fold_float_infinity_nan() -> tensor<7xf32> { dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000, 0.0]> : tensor<7xf32> } : () -> tensor<7xf32> - %2 = "tosa.mul"(%0, %1) : (tensor<7xf32>, tensor<7xf32>) -> tensor<7xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %2 = "tosa.mul"(%0, %1, %shift) : (tensor<7xf32>, tensor<7xf32>, tensor<1xi8>) -> tensor<7xf32> return %2 : tensor<7xf32> } @@ -49,7 +51,8 @@ func.func @add_fold_float_overflow() -> tensor<2xf32> { dense<[2.1e+38, 1.1e+38]> : tensor<2xf32> } : () -> tensor<2xf32> - %2 = "tosa.mul"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %2 = "tosa.mul"(%0, %1, %shift) : (tensor<2xf32>, 
tensor<2xf32>, tensor<1xi8>) -> tensor<2xf32> return %2 : tensor<2xf32> } @@ -195,3 +198,4 @@ func.func @mul_fold_int_non_zero_shift() -> tensor<4xi32> { %2 = "tosa.mul"(%0, %1, %shift) : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32> return %2 : tensor<4xi32> } + diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 532babf11fd6a..05a3f34b8db38 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -272,7 +272,8 @@ func.func @fold_div_splat_i32() -> tensor { func.func @fold_mul_zero_rhs_f32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> - %mul = tosa.mul %arg0, %zero : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %arg0, %zero, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %[[ZERO]] return %mul : tensor } @@ -283,7 +284,8 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor) -> tensor { func.func @fold_mul_zero_lhs_f32(%arg0: tensor) -> tensor { %zero = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00> - %mul = tosa.mul %zero, %arg0 : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %zero, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %[[ZERO]] return %mul : tensor } @@ -317,7 +319,8 @@ func.func @fold_mul_zero_lhs_i32(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_mul_one_rhs_f32 func.func @fold_mul_one_rhs_f32(%arg0: tensor) -> tensor { %one = "tosa.const"() {value = dense<1.0> : tensor} : () -> tensor - %mul = tosa.mul %arg0, %one : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> 
tensor<1xi8> + %mul = tosa.mul %arg0, %one, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %arg0 return %mul : tensor } @@ -327,7 +330,8 @@ func.func @fold_mul_one_rhs_f32(%arg0: tensor) -> tensor { // CHECK-LABEL: @fold_mul_one_lhs_f32 func.func @fold_mul_one_lhs_f32(%arg0: tensor) -> tensor { %one = "tosa.const"() {value = dense<1.0> : tensor} : () -> tensor - %mul = tosa.mul %one, %arg0 : (tensor, tensor) -> tensor + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %one, %arg0, %shift : (tensor, tensor, tensor<1xi8>) -> tensor // CHECK: return %arg0 return %mul : tensor } @@ -373,7 +377,8 @@ func.func @fold_mul_splat_i8() -> tensor<10xi32> { func.func @fold_mul_splat_f32() -> tensor<10xf32> { %one = "tosa.const"() {value = dense<3.0> : tensor<10xf32>} : () -> tensor<10xf32> %two = "tosa.const"() {value = dense<2.0> : tensor<10xf32>} : () -> tensor<10xf32> - %mul = tosa.mul %one, %two : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %mul = tosa.mul %one, %two, %shift : (tensor<10xf32>, tensor<10xf32>, tensor<1xi8>) -> tensor<10xf32> // CHECK: %[[THREE:.+]] = "tosa.const"() <{value = dense<6.000000e+00> : tensor<10xf32>} // CHECK: return %[[THREE]] return %mul : tensor<10xf32> diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 3e8bc7cfde8da..467fa2e7df11c 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -817,26 +817,27 @@ func.func @test_transpose_conv2d_invalid_outshape(%arg0: tensor<1x32x32x8xf32>, // CHECK-LABEL: test_mul_type_mismatch func.func @test_mul_type_mismatch(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf16>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> // expected-error@+1 {{'tosa.mul' op requires the same element type for all operands}} - %0 = 
tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf16>) -> tensor<13x21x3xf32> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf16>, tensor<1xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } // ----- // CHECK-LABEL: test_mul_invalid_shift -func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { - %shift = "tosa.const"() {value = dense<0.0> : tensor} : () -> tensor - // expected-error@+1 {{'tosa.mul' op operand #2 must be 1D tensor of 8-bit signless integer values, but got 'tensor'}} - %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor) -> tensor<13x21x3xi32> - return %0 : tensor<13x21x3xi32> +func.func @test_mul_invalid_shift(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op require shift to be 0 for float type}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> } // ----- // CHECK-LABEL: test_mul_missing_shift func.func @test_mul_missing_shift(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { - // this is ok because mul's shift operand is optional for now + // expected-error@+1 {{'tosa.mul' op expected 3 operands, but found 2}} %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> return %0 : tensor<13x21x3xi32> } @@ -1148,3 +1149,30 @@ func.func @test_sub_with_unequal_result_ranks(%arg0: tensor<1x21x3xf32>, %arg1: %0 = tosa.sub %arg0, %arg1 : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> return %0 : tensor<1x13x21x3xf32> } + +// ----- +// CHECK-LABEL: test_mul_non_scalar_shift_2d +func.func @test_mul_non_scalar_shift_2d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> 
tensor<13x21x3xf32> { + %shift = "tosa.const"() <{value = dense<0> : tensor<1x1xi8>}> : () -> tensor<1x1xi8> + // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<1x1xi8>'}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1x1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: test_mul_non_scalar_shift_1d +func.func @test_mul_non_scalar_shift_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{value = dense<0> : tensor<2xi8>}> : () -> tensor<2xi8> + // expected-error@+1 {{'tosa.mul' op operand #2 must be tosa-conformant scalar tensor of 8-bit signless integer values, but got 'tensor<2xi8>'}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<2xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- +// CHECK-LABEL: test_mul_non_broadcast +func.func @test_mul_non_broadcast(%arg0: tensor<13x21x2xf32>, %arg1: tensor<3x1x3xf32>) -> tensor<13x21x3xf32> { + %shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // expected-error@+1 {{'tosa.mul' op operands don't have broadcast-compatible shapes}} + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x2xf32>, tensor<3x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index bbbcb735e613e..4ac623578ce8b 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -355,7 +355,8 @@ func.func @test_mul_scalar_with_unranked_output(%arg0: tensor, %arg1: tenso // ----- // CHECK-LABEL: mul func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.mul %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> + %shift = 
"tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + %0 = tosa.mul %arg0, %arg1, %shift : (tensor<13x21x3xf32>, tensor<13x1x3xf32>, tensor<1xi8>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index b8988c2b5728b..6e345556ff359 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -114,23 +114,24 @@ func.func @test_binary_scalar_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32>) // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %3 = tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %3 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<4xf32> + %4 = tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<*xf32> // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %5 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %6 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %7 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %7 
= tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %8 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %9 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> return } @@ -148,23 +149,24 @@ func.func @test_binary_broadcast_f32(%arg0 : tensor<4xf32>, %arg1 : tensor<1xf32 // CHECK: tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> %2 = tosa.minimum %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> - // CHECK: tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %3 = tosa.mul %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %3 = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> + // CHECK: tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<4xf32> + %4 = tosa.mul %arg0, %arg1, %3 : (tensor<4xf32>, tensor<1xf32>, tensor<1xi8>) -> tensor<*xf32> // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %4 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %5 = tosa.pow %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> - %5 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> + %6 = tosa.sub %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xf32> // CHECK: tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %6 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %7 = tosa.equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> // CHECK: tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %7 = tosa.greater %arg0, %arg1 : 
(tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %8 = tosa.greater %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> // CHECK: tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xi1> - %8 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> + %9 = tosa.greater_equal %arg0, %arg1 : (tensor<4xf32>, tensor<1xf32>) -> tensor<*xi1> return } @@ -211,10 +213,10 @@ func.func @test_binary_i32(%arg0 : tensor<4xi32>, %arg1 : tensor<1xi32>) -> () { %11 = tosa.mul %arg0, %arg1, %shift : (tensor<4xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<*xi32> // CHECK: tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %12 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + %13 = tosa.pow %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> // CHECK: tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> - %13 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> + %14 = tosa.sub %arg0, %arg1 : (tensor<4xi32>, tensor<1xi32>) -> tensor<*xi32> return } diff --git a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir index e70f3644da646..3a293009a5455 100644 --- a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir @@ -193,6 +193,7 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x // CHECK-LABEL: @test_resnet18_common_case // COM: note that %74 is now represented by %arg2 +// CHECK-DAG: %[[CONST0:.+]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> // CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> // CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> // CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<9.99999974E-6> : tensor<1xf32>}> : () -> 
tensor<1xf32> @@ -205,15 +206,16 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> tensor<4x3x2x // CHECK-DAG: %[[VAL_12:.*]] = tosa.sub %arg2, %[[VAL_11]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> // CHECK-DAG: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_14:.*]] = tosa.reshape %[[VAL_9]], %[[VAL_13]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> -// CHECK-DAG: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> +// CHECK-DAG: %[[VAL_15:.*]] = tosa.mul %[[VAL_12]], %[[VAL_14]], %[[CONST0]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>, tensor<1xi8>) -> tensor<1x112x112x64xf32> // CHECK-DAG: %[[VAL_16:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_17:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_16]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> -// CHECK-DAG: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> +// CHECK-DAG: %[[VAL_18:.*]] = tosa.mul %[[VAL_15]], %[[VAL_17]], %[[CONST0]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>, tensor<1xi8>) -> tensor<1x112x112x64xf32> // CHECK-DAG: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_20:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_19]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> // CHECK-DAG: %[[VAL_21:.*]] = tosa.add %[[VAL_18]], %[[VAL_20]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> // CHECK-DAG: %[[VAL_22:.*]] = tosa.clamp %[[VAL_21]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> func.func @test_resnet18_common_case(%arg0: 
tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> { + %shift = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1xi8> %58 = tosa.const_shape {value = dense<[1, 64, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> %59 = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> %60 = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> @@ -228,9 +230,9 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32 %79 = tosa.reshape %arg0, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> %80 = tosa.sub %75, %79 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> %81 = tosa.reshape %78, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> - %82 = tosa.mul %80, %81 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> + %82 = tosa.mul %80, %81, %shift : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>, tensor<1xi8>) -> tensor<1x64x112x112xf32> %83 = tosa.reshape %60, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> - %84 = tosa.mul %82, %83 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> + %84 = tosa.mul %82, %83, %shift : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>, tensor<1xi8>) -> tensor<1x64x112x112xf32> %85 = tosa.reshape %59, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> %86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> %87 = tosa.clamp %86 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32>