diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index ceb86a3ee34d9..21451c83a2cc1 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -394,10 +394,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> { let arguments = (ins Tosa_Tensor:$input, - I64Attr:$min_int, - I64Attr:$max_int, - Tosa_FloatAttr:$min_fp, - Tosa_FloatAttr:$max_fp, + Tosa_IntOrFloatAttr:$min_val, + Tosa_IntOrFloatAttr:$max_val, DefaultValuedAttr:$nan_mode ); diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td index 08a3c7b46d395..28b2f90c90052 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td @@ -205,6 +205,14 @@ def Tosa_FloatAttr : Attr($_self)">, let returnType = [{ ::mlir::APFloat }]; } +def Tosa_IntegerAttr : Attr($_self)">, + "arbitrary integer attribute"> { + let storageType = [{ ::mlir::IntegerAttr }]; + let returnType = [{ ::llvm::APInt }]; +} + +def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>; + //===----------------------------------------------------------------------===// // Iterable attributes. 
//===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 5064194226b71..2c291fc12430c 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -407,8 +407,8 @@ static Value createLinalgBodyCalculationForElementwiseOp( // tosa::ClampOp if (isa(op) && isa(elementTy)) { bool losesInfo = false; - APFloat minApf = cast(op->getAttr("min_fp")).getValue(); - APFloat maxApf = cast(op->getAttr("max_fp")).getValue(); + APFloat minApf = cast(op->getAttr("min_val")).getValue(); + APFloat maxApf = cast(op->getAttr("max_val")).getValue(); minApf.convert(cast(elementTy).getFloatSemantics(), APFloat::rmNearestTiesToEven, &losesInfo); maxApf.convert(cast(elementTy).getFloatSemantics(), @@ -423,9 +423,9 @@ static Value createLinalgBodyCalculationForElementwiseOp( if (isa(op) && isa(elementTy)) { auto intTy = cast(elementTy); int64_t min = - cast(op->getAttr("min_int")).getValue().getSExtValue(); + cast(op->getAttr("min_val")).getValue().getSExtValue(); int64_t max = - cast(op->getAttr("max_int")).getValue().getSExtValue(); + cast(op->getAttr("max_val")).getValue().getSExtValue(); int64_t minRepresentable = std::numeric_limits::min(); int64_t maxRepresentable = std::numeric_limits::max(); diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 846e2986129a4..b70b7cf30b1a2 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -302,6 +302,9 @@ struct SelectToClampOptimization : public OpRewritePattern { DenseElementsAttr onFalseAttr; DenseElementsAttr onTrueAttr; + const Type resultElemTy = op.getType().getElementType(); + const bool resultElemTyIsUnsignedInteger = resultElemTy.isUnsignedInteger(); + // Case one: // %0 = tosa.greater_equal(input, 
cmp) // %1 = tosa.select(%0, input, cmp) @@ -328,10 +331,8 @@ struct SelectToClampOptimization : public OpRewritePattern { } const auto inputElementType = geqIn2Attr.getElementType(); - int64_t clampIntMin = std::numeric_limits::min(); - int64_t clampIntMax = std::numeric_limits::max(); - FloatAttr clampFloatMin; - FloatAttr clampFloatMax; + Attribute clampMin; + Attribute clampMax; if (auto integerType = dyn_cast(inputElementType)) { int64_t splatValue; if (integerType.isUnsigned()) { @@ -343,26 +344,33 @@ struct SelectToClampOptimization : public OpRewritePattern { } else { splatValue = geqIn2Attr.getSplatValue().getSExtValue(); } - clampFloatMin = - rewriter.getF32FloatAttr(-std::numeric_limits::infinity()); - clampFloatMax = - rewriter.getF32FloatAttr(std::numeric_limits::infinity()); if (isCaseOne) { - clampIntMin = splatValue; + clampMin = rewriter.getIntegerAttr(resultElemTy, splatValue); + clampMax = rewriter.getIntegerAttr( + resultElemTy, + resultElemTyIsUnsignedInteger + ? APInt::getMaxValue(resultElemTy.getIntOrFloatBitWidth()) + : APInt::getSignedMaxValue( + resultElemTy.getIntOrFloatBitWidth())); } else { - clampIntMax = splatValue; + clampMax = rewriter.getIntegerAttr(resultElemTy, splatValue); + clampMin = rewriter.getIntegerAttr( + resultElemTy, + resultElemTyIsUnsignedInteger + ? 
APInt::getMinValue(resultElemTy.getIntOrFloatBitWidth()) + : APInt::getSignedMinValue( + resultElemTy.getIntOrFloatBitWidth())); } } else if (isa(inputElementType)) { auto splatValue = geqIn2Attr.getSplatValue(); if (isCaseOne) { - clampFloatMin = rewriter.getFloatAttr(inputElementType, splatValue); - clampFloatMax = rewriter.getFloatAttr( - inputElementType, - APFloat::getInf(splatValue.getSemantics(), false)); + clampMin = rewriter.getFloatAttr(resultElemTy, splatValue); + clampMax = rewriter.getFloatAttr( + resultElemTy, APFloat::getInf(splatValue.getSemantics(), false)); } else { - clampFloatMin = rewriter.getFloatAttr( - inputElementType, APFloat::getInf(splatValue.getSemantics(), true)); - clampFloatMax = rewriter.getFloatAttr(inputElementType, splatValue); + clampMin = rewriter.getFloatAttr( + resultElemTy, APFloat::getInf(splatValue.getSemantics(), true)); + clampMax = rewriter.getFloatAttr(resultElemTy, splatValue); } } @@ -380,9 +388,8 @@ struct SelectToClampOptimization : public OpRewritePattern { input); } - rewriter.replaceOpWithNewOp( - op, op.getType(), input, rewriter.getI64IntegerAttr(clampIntMin), - rewriter.getI64IntegerAttr(clampIntMax), clampFloatMin, clampFloatMax); + rewriter.replaceOpWithNewOp(op, op.getType(), input, + clampMin, clampMax); return success(); } @@ -606,10 +613,12 @@ struct ClampIsNoOp : public OpRewritePattern { if (isa(inputElementType)) { // Unlike integer types, floating point types can represent infinity. 
- auto minClamp = op.getMinFp(); - auto maxClamp = op.getMaxFp(); - bool isMin = minClamp.isInfinity() && minClamp.isNegative(); - bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative(); + auto minClamp = + llvm::cast(op.getMinValAttr()).getValue(); + auto maxClamp = + llvm::cast(op.getMaxValAttr()).getValue(); + bool isMin = minClamp.isNegInfinity(); + bool isMax = maxClamp.isInfinity(); if (isMin && isMax) { rewriter.replaceOp(op, input); @@ -619,8 +628,10 @@ struct ClampIsNoOp : public OpRewritePattern { } if (inputElementType.isUnsignedInteger()) { - int64_t minClamp = op.getMinInt(); - int64_t maxClamp = op.getMaxInt(); + int64_t minClamp = + llvm::cast(op.getMinValAttr()).getUInt(); + int64_t maxClamp = + llvm::cast(op.getMaxValAttr()).getUInt(); int64_t intMin = APInt::getMinValue(inputElementType.getIntOrFloatBitWidth()) @@ -637,8 +648,10 @@ struct ClampIsNoOp : public OpRewritePattern { } if (llvm::isa(inputElementType)) { - int64_t minClamp = op.getMinInt(); - int64_t maxClamp = op.getMaxInt(); + int64_t minClamp = + llvm::cast(op.getMinValAttr()).getInt(); + int64_t maxClamp = + llvm::cast(op.getMaxValAttr()).getInt(); int64_t intMin = APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth()) @@ -693,9 +706,10 @@ struct ClampClampOptimization : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::ClampOp op, PatternRewriter &rewriter) const override { + Value input = op.getInput(); + // Check the input to the CLAMP op is itself a CLAMP. - auto clampOp = - dyn_cast_if_present(op.getInput().getDefiningOp()); + auto clampOp = dyn_cast_if_present(input.getDefiningOp()); if (!clampOp) return failure(); @@ -705,34 +719,87 @@ struct ClampClampOptimization : public OpRewritePattern { if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE") return failure(); - // Check we have intersecting ranges. 
- const auto opMinInt = op.getMinInt(); - const auto opMaxInt = op.getMaxInt(); - const auto clampOpMinInt = clampOp.getMinInt(); - const auto clampOpMaxInt = clampOp.getMaxInt(); - ClampRange opRangeIntRange(opMinInt, opMaxInt); - ClampRange clampRangeIntRange(clampOpMinInt, clampOpMaxInt); - if (!opRangeIntRange.intersects(clampRangeIntRange)) - return failure(); + auto maxValAttr = op.getMaxValAttr(); + auto minValAttr = op.getMinValAttr(); + auto clampOpMaxValAttr = clampOp.getMaxValAttr(); + auto clampOpMinValAttr = clampOp.getMinValAttr(); - const auto opMinFloat = op.getMinFp(); - const auto opMaxFloat = op.getMaxFp(); - const auto clampOpMinFloat = clampOp.getMinFp(); - const auto clampOpMaxFloat = clampOp.getMaxFp(); - ClampRange opRangeFloatRange(opMinFloat, opMaxFloat); - ClampRange clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat); - if (!opRangeFloatRange.intersects(clampRangeFloatRange)) - return failure(); + auto inputEType = llvm::cast(input.getType()).getElementType(); + if (auto quantType = + llvm::dyn_cast(inputEType)) { + inputEType = quantType.getStorageType(); + } + + Attribute newMinValAttr, newMaxValAttr; + if (mlir::isa(inputEType)) { + auto floatMaxValAttr = cast(maxValAttr); + auto floatMinValAttr = cast(minValAttr); + auto clampOpFloatMaxValAttr = cast(clampOpMaxValAttr); + auto clampOpFloatMinValAttr = cast(clampOpMinValAttr); + + // Check we have intersecting ranges. + const auto opMinFloat = floatMinValAttr.getValue(); + const auto opMaxFloat = floatMaxValAttr.getValue(); + const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue(); + const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue(); + ClampRange opRangeFloatRange(opMinFloat, opMaxFloat); + ClampRange clampRangeFloatRange(clampOpMinFloat, + clampOpMaxFloat); + if (!opRangeFloatRange.intersects(clampRangeFloatRange)) + return failure(); + + // Run the transformation. 
+ auto newMinVal = std::max(opMinFloat, clampOpMinFloat); + auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat); + newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal); + newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal); + } else { + assert(mlir::isa(inputEType)); + auto intMaxValAttr = cast(maxValAttr); + auto intMinValAttr = cast(minValAttr); + auto clampOpIntMaxValAttr = cast(clampOpMaxValAttr); + auto clampOpIntMinValAttr = cast(clampOpMinValAttr); + + if (inputEType.isUnsignedInteger()) { + // Check we have intersecting ranges. + const auto opMinInt = intMinValAttr.getUInt(); + const auto opMaxInt = intMaxValAttr.getUInt(); + const auto clampOpMinInt = clampOpIntMinValAttr.getUInt(); + const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt(); + ClampRange opRangeIntRange(opMinInt, opMaxInt); + ClampRange clampRangeIntRange(clampOpMinInt, + clampOpMaxInt); + if (!opRangeIntRange.intersects(clampRangeIntRange)) + return failure(); + + // Run the transformation. + auto newMinVal = std::max(opMinInt, clampOpMinInt); + auto newMaxVal = std::min(opMaxInt, clampOpMaxInt); + newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal); + newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal); + } else { + // Check we have intersecting ranges. + const auto opMinInt = intMinValAttr.getInt(); + const auto opMaxInt = intMaxValAttr.getInt(); + const auto clampOpMinInt = clampOpIntMinValAttr.getInt(); + const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt(); + ClampRange opRangeIntRange(opMinInt, opMaxInt); + ClampRange clampRangeIntRange(clampOpMinInt, + clampOpMaxInt); + if (!opRangeIntRange.intersects(clampRangeIntRange)) + return failure(); + + // Run the transformation. 
+ auto newMinVal = std::max(opMinInt, clampOpMinInt); + auto newMaxVal = std::min(opMaxInt, clampOpMaxInt); + newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal); + newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal); + } + } - // Run the transformation. - const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat(); - const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat(); - const auto minInt = std::max(opMinInt, clampOpMinInt); - const auto maxInt = std::min(opMaxInt, clampOpMaxInt); rewriter.replaceOpWithNewOp( op, {op->getLoc(), clampOp->getLoc()}, op.getType(), clampOp.getInput(), - rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt), - rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp), + newMinValAttr, newMaxValAttr, rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE" : opNanMode)); return success(); @@ -973,25 +1040,28 @@ struct MinToClampOptimization : public OpRewritePattern { Value input = op.getInput1(); auto elementTy = llvm::cast(input.getType()).getElementType(); - int64_t minInt = std::numeric_limits::min(); - float minFp = std::numeric_limits::lowest(); - - int64_t maxInt; - float maxFp; - if (isa(elementTy)) { + Attribute minAttr; + Attribute maxAttr; + if (auto floatTy = dyn_cast(elementTy)) { auto constMin = constant.getSplatValue(); - maxFp = constMin.convertToFloat(); - maxInt = constMin.convertToFloat(); + maxAttr = rewriter.getFloatAttr(floatTy, constMin); + minAttr = rewriter.getFloatAttr( + floatTy, APFloat::getInf(constMin.getSemantics(), /*Negative=*/true)); + } else if (auto intTy = cast(elementTy); + intTy.isUnsignedInteger()) { + auto constMin = constant.getSplatValue(); + maxAttr = rewriter.getIntegerAttr(intTy, constMin); + minAttr = + rewriter.getIntegerAttr(intTy, APInt::getMinValue(intTy.getWidth())); } else { auto constMin = constant.getSplatValue(); - maxFp = constMin.getSExtValue(); - maxInt = constMin.getSExtValue(); + maxAttr = 
rewriter.getIntegerAttr(intTy, constMin); + minAttr = rewriter.getIntegerAttr( + intTy, APInt::getSignedMinValue(intTy.getWidth())); } - rewriter.replaceOpWithNewOp( - op, op.getType(), input, rewriter.getI64IntegerAttr(minInt), - rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp), - rewriter.getF32FloatAttr(maxFp)); + rewriter.replaceOpWithNewOp(op, op.getType(), input, minAttr, + maxAttr); return success(); } @@ -1016,25 +1086,28 @@ struct MaxToClampOptimization : public OpRewritePattern { Value input = op.getInput1(); auto elementTy = llvm::cast(input.getType()).getElementType(); - int64_t maxInt = std::numeric_limits::max(); - float maxFp = std::numeric_limits::max(); - - int64_t minInt; - float minFp; - if (isa(elementTy)) { + Attribute minAttr; + Attribute maxAttr; + if (auto floatTy = dyn_cast(elementTy)) { auto constMax = constant.getSplatValue(); - minFp = constMax.convertToFloat(); - minInt = constMax.convertToFloat(); + minAttr = rewriter.getFloatAttr(floatTy, constMax); + maxAttr = rewriter.getFloatAttr(floatTy, + APFloat::getInf(constMax.getSemantics())); + } else if (auto intTy = cast(elementTy); + intTy.isUnsignedInteger()) { + auto constMax = constant.getSplatValue(); + minAttr = rewriter.getIntegerAttr(intTy, constMax); + maxAttr = + rewriter.getIntegerAttr(intTy, APInt::getMaxValue(intTy.getWidth())); } else { auto constMax = constant.getSplatValue(); - minFp = constMax.getSExtValue(); - minInt = constMax.getSExtValue(); + minAttr = rewriter.getIntegerAttr(intTy, constMax); + maxAttr = rewriter.getIntegerAttr( + intTy, APInt::getSignedMaxValue(intTy.getWidth())); } - rewriter.replaceOpWithNewOp( - op, op.getType(), input, rewriter.getI64IntegerAttr(minInt), - rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp), - rewriter.getF32FloatAttr(maxFp)); + rewriter.replaceOpWithNewOp(op, op.getType(), input, minAttr, + maxAttr); return success(); } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp 
b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp index b667d61ae1d1f..f1fcc17cf360f 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp @@ -476,26 +476,40 @@ LogicalResult tosa::ClampOp::verify() { llvm::dyn_cast(inputETy)) { inputETy = quantType.getStorageType(); } - mlir::Type maxFpType = getMaxFpAttr().getType(); - mlir::Type minFpType = getMinFpAttr().getType(); mlir::Type outputETy = llvm::cast(getOutput().getType()).getElementType(); if (auto quantType = llvm::dyn_cast(outputETy)) { outputETy = quantType.getStorageType(); } - unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth(); - if (inputETy != outputETy) return emitOpError("input/output element types are incompatible."); - // If input datatype is float, check that the two min/max_fp attributes - // share the same type and that their type is either the same of the input's - // datatype, or a float type whose bitwidth > input datatype bitwidth. - if (!inputETy.isInteger(dataTypeBitWidth)) { - if (((maxFpType != minFpType) || - (maxFpType != inputETy && maxFpType.getIntOrFloatBitWidth() <= - inputETy.getIntOrFloatBitWidth()))) + auto maxValAttr = getMaxValAttr(); + auto minValAttr = getMinValAttr(); + + unsigned dataTypeBitWidth = inputETy.getIntOrFloatBitWidth(); + + if (inputETy.isInteger(dataTypeBitWidth)) { + // if input datatype is integer, check that the min_val/max_val attributes + // are integer attributes, and that their type is the same as the input's + // datatype + auto intMaxValAttr = mlir::dyn_cast(maxValAttr); + auto intMinValAttr = mlir::dyn_cast(minValAttr); + if (!intMaxValAttr || !intMinValAttr || + (intMaxValAttr.getType() != intMinValAttr.getType()) || + (intMaxValAttr.getType() != inputETy)) + return emitOpError("min/max attributes types are incompatible with " + "input/output element types."); + } else { + // otherwise, input datatype is float, check that the min_val/max_val + // attributes share the same type and that their type is the same as the 
+ // input's datatype + auto floatMaxValAttr = mlir::dyn_cast(maxValAttr); + auto floatMinValAttr = mlir::dyn_cast(minValAttr); + if (!floatMaxValAttr || !floatMinValAttr || + (floatMaxValAttr.getType() != floatMinValAttr.getType()) || + (floatMaxValAttr.getType() != inputETy)) return emitOpError("min/max attributes types are incompatible with " "input/output element types."); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp index df47dc89d5332..58b0d31bc403c 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp @@ -888,8 +888,8 @@ struct TosaFoldConstantClamp return {}; } - auto lowerBoundVal = op.getMinIntAttr().getValue(); - auto upperBoundVal = op.getMaxIntAttr().getValue(); + auto lowerBoundVal = cast(op.getMinValAttr()).getValue(); + auto upperBoundVal = cast(op.getMaxValAttr()).getValue(); assert(lowerBoundVal.getBitWidth() == upperBoundVal.getBitWidth()); return applyClamp(values, lowerBoundVal, upperBoundVal, op.getType()); @@ -898,8 +898,8 @@ struct TosaFoldConstantClamp /// Called when the values.getElementType() is FloatType. 
DenseElementsAttr computeFloat(DenseElementsAttr values, PatternRewriter &rewriter, ClampOp op) const { - auto lowerBoundVal = op.getMinFp(); - auto upperBoundVal = op.getMaxFp(); + auto lowerBoundVal = cast(op.getMinValAttr()).getValue(); + auto upperBoundVal = cast(op.getMaxValAttr()).getValue(); assert(APFloat::getSizeInBits(lowerBoundVal.getSemantics()) == APFloat::getSizeInBits(upperBoundVal.getSemantics())); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir index 2780a56a6f4cf..a0df696c53b2d 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir @@ -549,7 +549,7 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () { // CHECK: linalg.generic // CHECK: arith.minimumf // CHECK: arith.maximumf - %18 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> + %18 = tosa.clamp %0 {min_val = 1.0 : f32, max_val = 5.0 : f32} : (tensor<1xf32>) -> tensor<1xf32> // CHECK: linalg.generic // CHECK: arith.negf @@ -795,35 +795,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns // CHECK: linalg.generic // CHECK-DAG: arith.maxsi // CHECK-DAG: arith.minsi - %19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32> + %19 = tosa.clamp %0 {min_val = 1 : i32, max_val = 5 : i32} : (tensor<1xi32>) -> tensor<1xi32> // CHECK: linalg.generic // CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32 // CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32 // CHECK-DAG: arith.maxui %[[LB]], // CHECK-DAG: arith.minui %[[UB]], - %u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32> - - // CHECK: linalg.generic - // CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32 - // CHECK-DAG: %[[UB:.*]] 
= arith.constant -1 : i32 - // CHECK-DAG: arith.maxui %[[LB]], - // CHECK-DAG: arith.minui %[[UB]], - %u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32> - - // CHECK: linalg.generic - // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32 - // CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32 - // CHECK-DAG: arith.maxui %[[LB]], - // CHECK-DAG: arith.minui %[[UB]], - %u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32> - - // CHECK: linalg.generic - // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64 - // CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64 - // CHECK-DAG: arith.maxui %[[LB]], - // CHECK-DAG: arith.minui %[[UB]], - %u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64> + %u0 = tosa.clamp %unsigned {min_val = 4 : ui32, max_val = 32 : ui32} : (tensor<1xui32>) -> tensor<1xui32> // CHECK: linalg.generic // CHECK: arith.trunci @@ -877,15 +856,7 @@ func.func @test_i8(%arg0: tensor<1xi8>) -> () { // CHECK-DAG: %[[C126:.+]] = arith.constant 126 // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]] // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]] - %0 = tosa.clamp %arg0 {min_int = -127 : i64, max_int = 126 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> tensor<1xi8> - - // CHECK: linalg.generic - // CHECK: ^bb0(%[[ARG1:.+]]: i8, - // CHECK-DAG: %[[C128:.+]] = arith.constant -128 - // CHECK-DAG: %[[C127:.+]] = arith.constant 127 - // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C128]], %[[ARG1]] - // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C127]], %[[LOWER]] - %1 = tosa.clamp %arg0 {min_int = -130 : i64, max_int = 130 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi8>) -> 
tensor<1xi8> + %0 = tosa.clamp %arg0 {min_val = -127 : i8, max_val = 126 : i8} : (tensor<1xi8>) -> tensor<1xi8> return } @@ -900,7 +871,7 @@ func.func @test_i64(%arg0: tensor<1xi64>) -> () { // CHECK-DAG: %[[C126:.+]] = arith.constant 9223372036854775807 // CHECK-DAG: %[[LOWER:.+]] = arith.maxsi %[[C127]], %[[ARG1]] // CHECK-DAG: %[[CLAMPED:.+]] = arith.minsi %[[C126]], %[[LOWER]] - %0 = tosa.clamp %arg0 {min_int = -9223372036854775808 : i64, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, max_fp = 0.0 : f32} : (tensor<1xi64>) -> tensor<1xi64> + %0 = tosa.clamp %arg0 {min_val = -9223372036854775808 : i64, max_val = 9223372036854775807 : i64} : (tensor<1xi64>) -> tensor<1xi64> return } @@ -915,7 +886,7 @@ func.func @test_clamp_f16(%arg0: tensor<1xf16>) -> () { // CHECK-DAG: %[[C6:.+]] = arith.constant 6.0 // CHECK-DAG: %[[MIN:.+]] = arith.minimumf %[[ARG1]], %[[C6]] // CHECK-DAG: %[[MAX:.+]] = arith.maximumf %[[MIN]], %[[C0]] - %0 = tosa.clamp %arg0 {min_int = 0 : i64, max_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 6.0 : f32} : (tensor<1xf16>) -> tensor<1xf16> + %0 = tosa.clamp %arg0 {min_val = 0.0 : f16, max_val = 6.0 : f16} : (tensor<1xf16>) -> tensor<1xf16> return } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 362cabcfd8fdd..c4e655d6d25ce 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -86,25 +86,16 @@ func.func @cast_no_fold_double2(%arg0: tensor) -> tensor { // CHECK-LABEL: @clamp_i32_not_noop func.func @clamp_i32_not_noop(%arg0: tensor<4xi32>) -> tensor<4xi32> { // CHECK: tosa.clamp - %0 = tosa.clamp %arg0 {min_int = 1 : i64, max_int = 4 : i64, min_fp = 1.0 : f32, max_fp = 4.0 : f32} : (tensor<4xi32>) -> tensor<4xi32> + %0 = tosa.clamp %arg0 {min_val = 1 : i32, max_val = 4 : i32} : (tensor<4xi32>) -> tensor<4xi32> return %0 : tensor<4xi32> } // ----- -// CHECK-LABEL: @clamp_f16_not_noop -func.func @clamp_f16_not_noop(%arg0: tensor<4xf16>) 
-> tensor<4xf16> { - // CHECK: tosa.clamp - %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xf16>) -> tensor<4xf16> - return %0 : tensor<4xf16> -} - -// ----- - // CHECK-LABEL: @clamp_f32_not_noop func.func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK: tosa.clamp - %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xf32>) -> tensor<4xf32> + %0 = tosa.clamp %arg0 {min_val = -3.40282347E+38 : f32, max_val = 3.40282347E+38 : f32} : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -114,8 +105,8 @@ func.func @clamp_f32_not_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> { func.func @clamp_f16_is_noop(%arg0: tensor<4xf16>) -> tensor<4xf16> { // CHECK: return %arg0 // CHECK-NOT: "tosa.clamp" - // 0xFF800000 and 0x7F800000 are respectively negative and positive F32 infinity. - %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf16>) -> tensor<4xf16> + // 0x7C00 and 0xFC00 are respectively positive and negative F32 infinity. + %0 = tosa.clamp %arg0 {max_val = 0x7C00 : f16, min_val = 0xFC00 : f16} : (tensor<4xf16>) -> tensor<4xf16> return %0 : tensor<4xf16> } @@ -126,7 +117,7 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> { // CHECK: return %arg0 // CHECK-NOT: "tosa.clamp" // 0xFF800000 and 0x7F800000 are respectively negative and positive F32 infinity. 
- %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = 0xFF800000 : f32, max_fp = 0x7F800000 : f32} : (tensor<4xf32>) -> tensor<4xf32> + %0 = tosa.clamp %arg0 {min_val = 0xFF800000 : f32, max_val = 0x7F800000 : f32} : (tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -136,7 +127,7 @@ func.func @clamp_f32_is_noop(%arg0: tensor<4xf32>) -> tensor<4xf32> { func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> { // CHECK: return %arg0 // CHECK-NOT: tosa.clamp - %0 = tosa.clamp %arg0 {min_int = -128 : i64, max_int = 127 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xi8>) -> tensor<4xi8> + %0 = tosa.clamp %arg0 {min_val = -128 : i8, max_val = 127 : i8} : (tensor<4xi8>) -> tensor<4xi8> return %0 : tensor<4xi8> } @@ -146,7 +137,7 @@ func.func @clamp_int8_is_noop(%arg0: tensor<4xi8>) -> tensor<4xi8> { func.func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> { // CHECK: return %arg0 // CHECK-NOT: tosa.clamp - %0 = tosa.clamp %arg0 {min_int = -32768 : i64, max_int = 32767 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xi16>) -> tensor<4xi16> + %0 = tosa.clamp %arg0 {min_val = -32768 : i16, max_val = 32767 : i16} : (tensor<4xi16>) -> tensor<4xi16> return %0 : tensor<4xi16> } @@ -156,7 +147,7 @@ func.func @clamp_int16_is_noop(%arg0: tensor<4xi16>) -> tensor<4xi16> { func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> { // CHECK: return %arg0 // CHECK-NOT: tosa.clamp - %0 = tosa.clamp %arg0 {min_int = 0 : i64, max_int = 255 : i64, min_fp = -3.40282347E+38 : f32, max_fp = 3.40282347E+38 : f32} : (tensor<4xui8>) -> tensor<4xui8> + %0 = tosa.clamp %arg0 {min_val = 0 : ui8, max_val = 255 : ui8} : (tensor<4xui8>) -> tensor<4xui8> return %0 : tensor<4xui8> } @@ -164,35 +155,35 @@ func.func @clamp_uint8_is_noop(%arg0: tensor<4xui8>) -> tensor<4xui8> { // CHECK-LABEL: @clamp_twice_is_single_clamp func.func @clamp_twice_is_single_clamp(%arg0: 
tensor<4xi8>) -> tensor<4xi8> { - // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64} - %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64} : (tensor<4xi8>) -> tensor<4xi8> - %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64} : (tensor<4xi8>) -> tensor<4xi8> + // CHECK: tosa.clamp %arg0 {max_val = 2 : i8, min_val = -2 : i8} + %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8} : (tensor<4xi8>) -> tensor<4xi8> + %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8} : (tensor<4xi8>) -> tensor<4xi8> return %1 : tensor<4xi8> } func.func @clamp_minimum_i32(%arg0: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: tosa.clamp %arg0 {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = -3.40282347E+38 : f32, min_int = -2147483648 : i64} + // CHECK: tosa.clamp %arg0 {max_val = 6 : i32, min_val = -2147483648 : i32} %0 = "tosa.const"() <{value = dense<6> : tensor<1xi32>}> : () -> tensor<1xi32> %1 = tosa.minimum %arg0, %0 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> return %1 : tensor<4xi32> } func.func @clamp_minimum_f32(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: tosa.clamp %arg0 {max_fp = 6.000000e+00 : f32, max_int = 6 : i64, min_fp = -3.40282347E+38 : f32, min_int = -2147483648 : i64} + // CHECK: tosa.clamp %arg0 {max_val = 6.000000e+00 : f32, min_val = 0xFF800000 : f32} %0 = "tosa.const"() <{value = dense<6.0> : tensor<1xf32>}> : () -> tensor<1xf32> %1 = tosa.minimum %arg0, %0 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> } func.func @clamp_maximum_i32(%arg0: tensor<4xi32>) -> tensor<4xi32> { - // CHECK: tosa.clamp %arg0 {max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = -6.000000e+00 : f32, min_int = -6 : i64} + // CHECK: tosa.clamp %arg0 {max_val = 2147483647 : i32, min_val = -6 : i32} %0 = "tosa.const"() <{value = 
dense<-6> : tensor<1xi32>}> : () -> tensor<1xi32> %1 = tosa.maximum %arg0, %0 : (tensor<4xi32>, tensor<1xi32>) -> tensor<4xi32> return %1 : tensor<4xi32> } func.func @clamp_maximum_f32(%arg0: tensor<4xf32>) -> tensor<4xf32> { - // CHECK: tosa.clamp %arg0 {max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = -6.000000e+00 : f32, min_int = -6 : i64} + // CHECK: tosa.clamp %arg0 {max_val = 0x7F800000 : f32, min_val = -6.000000e+00 : f32} %0 = "tosa.const"() <{value = dense<-6.0> : tensor<1xf32>}> : () -> tensor<1xf32> %1 = tosa.maximum %arg0, %0 : (tensor<4xf32>, tensor<1xf32>) -> tensor<4xf32> return %1 : tensor<4xf32> @@ -235,10 +226,10 @@ func.func @concat_fold_zero_size(%arg0: tensor, %arg1: tensor, // CHECK: @disjoint_clamp_twice_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>) func.func @disjoint_clamp_twice_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> { - // CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = -5.000000e+00 : f32, max_int = -5 : i64, min_fp = -1.000000e+00 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8> - // CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 5 : i64, min_fp = 1.000000e+00 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8> - %0 = tosa.clamp %arg0 {max_fp = -5.0 : f32, max_int = -5 : i64, min_fp = -1.0 : f32, min_int = -10 : i64} : (tensor<4xi8>) -> tensor<4xi8> - %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 5 : i64, min_fp = 1.0 : f32, min_int = 1 : i64} : (tensor<4xi8>) -> tensor<4xi8> + // CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_val = -5 : i8, min_val = -10 : i8} : (tensor<4xi8>) -> tensor<4xi8> + // CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 5 : i8, min_val = 1 : i8} : (tensor<4xi8>) -> tensor<4xi8> + %0 = tosa.clamp %arg0 {max_val = -5 : i8, min_val = -10 : i8} : (tensor<4xi8>) -> tensor<4xi8> + %1 = tosa.clamp %0 {max_val = 5 : i8, min_val = 1 : i8} : (tensor<4xi8>) -> tensor<4xi8> return %1 : tensor<4xi8> } @@ -246,9 
+237,9 @@ func.func @disjoint_clamp_twice_is_not_single_clamp(%arg0: tensor<4xi8>) -> tens // CHECK-LABEL: @clamp_twice_with_nan_propagate_is_single_clamp func.func @clamp_twice_with_nan_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> { - // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64} - %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8> - %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8> + // CHECK: tosa.clamp %arg0 {max_val = 2 : i8, min_val = -2 : i8} + %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8> + %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8> return %1 : tensor<4xi8> } @@ -256,9 +247,9 @@ func.func @clamp_twice_with_nan_propagate_is_single_clamp(%arg0: tensor<4xi8>) - // CHECK-LABEL: @clamp_twice_with_nan_ignore_is_single_clamp func.func @clamp_twice_with_nan_ignore_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> { - // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} - %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8> - %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8> + // CHECK: tosa.clamp %arg0 {max_val = 2 : i8, min_val = -2 : i8, nan_mode = "IGNORE"} + %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8> + %1 = tosa.clamp %0 {max_val = 2 : i8, 
min_val = -4 : i8, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8> return %1 : tensor<4xi8> } @@ -266,9 +257,9 @@ func.func @clamp_twice_with_nan_ignore_is_single_clamp(%arg0: tensor<4xi8>) -> t // CHECK-LABEL: @clamp_twice_with_nan_ignore_propagate_is_single_clamp func.func @clamp_twice_with_nan_ignore_propagate_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> { - // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} - %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8> - %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8> + // CHECK: tosa.clamp %arg0 {max_val = 2 : i8, min_val = -2 : i8, nan_mode = "IGNORE"} + %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8> + %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8> return %1 : tensor<4xi8> } @@ -276,10 +267,10 @@ func.func @clamp_twice_with_nan_ignore_propagate_is_single_clamp(%arg0: tensor<4 // CHECK: @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%[[INPUT:.*]]: tensor<4xi8>) func.func @clamp_twice_with_nan_propagate_ignore_is_not_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> { - // CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_fp = 3.000000e+00 : f32, max_int = 4 : i64, min_fp = -5.000000e+00 : f32, min_int = -2 : i64} : (tensor<4xi8>) -> tensor<4xi8> - // CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_fp = 5.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8> - %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64, nan_mode = 
"PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8> - %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8> + // CHECK: %[[CLAMP_1:.*]] = tosa.clamp %[[INPUT]] {max_val = 4 : i8, min_val = -2 : i8} : (tensor<4xi8>) -> tensor<4xi8> + // CHECK-NEXT: tosa.clamp %[[CLAMP_1]] {max_val = 2 : i8, min_val = -4 : i8, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8> + %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8, nan_mode = "PROPAGATE"} : (tensor<4xi8>) -> tensor<4xi8> + %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8, nan_mode = "IGNORE"} : (tensor<4xi8>) -> tensor<4xi8> return %1 : tensor<4xi8> } @@ -1480,8 +1471,9 @@ func.func nested @do_not_fold_reciprocal_int() -> tensor<3x600x1200xi32> { // CHECK-LABEL: @canonicalize_select_to_clamp func.func @canonicalize_select_to_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { -// CHECK: %[[VAL_1:.*]] = tosa.clamp %arg{{.*}} {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 1.500000e+00 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> -// CHECK: return %[[VAL_1]] : tensor<13x21x3xf32> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 0x7F800000 : f32, min_val = 1.500000e+00 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: return [[VAR_0_]] : tensor<13x21x3xf32> %0 = "tosa.const"() <{value = dense<1.500000e+00> : tensor<13x21x3xf32>}>: () -> tensor<13x21x3xf32> %1 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> %2 = tosa.select %1, %arg0, %0: ( tensor<13x21x3xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -1503,8 +1495,9 @@ func.func @canonicalize_select_to_clamp_not_splat(%arg0: tensor<4xi32>) -> tenso // CHECK-LABEL: @canonicalize_select_to_clamp_bf16 
func.func @canonicalize_select_to_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> { -// CHECK: %[[VAL_1:.*]] = tosa.clamp %arg{{.*}} {max_fp = 0x7F80 : bf16, max_int = 9223372036854775807 : i64, min_fp = 1.500000e+00 : bf16, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> -// CHECK: return %[[VAL_1]] : tensor<13x21x3xbf16> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> { +// CHECK: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 0x7F80 : bf16, min_val = 1.500000e+00 : bf16} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> +// CHECK: return [[VAR_0_]] : tensor<13x21x3xbf16> %0 = "tosa.const"() <{value = dense<1.500000e+00> : tensor<13x21x3xbf16>}>: () -> tensor<13x21x3xbf16> %1 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xbf16>, tensor<13x21x3xbf16>) -> tensor<13x21x3xi1> %2 = tosa.select %1, %arg0, %0: ( tensor<13x21x3xi1>, tensor<13x21x3xbf16>, tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> @@ -1526,8 +1519,9 @@ func.func @canonicalize_select_to_clamp_ui64(%arg0: tensor<13x21x3xui64>) -> ten // CHECK-LABEL: @canonicalize_select_to_clamp_ui4 func.func @canonicalize_select_to_clamp_ui4(%arg0: tensor<13x21x3xui4>) -> tensor<13x21x3xui4> { -// CHECK: %[[VAL_1:.*]] = tosa.clamp %arg{{.*}} {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 8 : i64} : (tensor<13x21x3xui4>) -> tensor<13x21x3xui4> -// CHECK: return %[[VAL_1]] : tensor<13x21x3xui4> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x3xui4>) -> tensor<13x21x3xui4> { +// CHECK: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 15 : ui4, min_val = 8 : ui4} : (tensor<13x21x3xui4>) -> tensor<13x21x3xui4> +// CHECK: return [[VAR_0_]] : tensor<13x21x3xui4> %0 = "tosa.const"() <{value = dense<8> : tensor<13x21x3xui4>}>: () -> tensor<13x21x3xui4> %1 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xui4>, tensor<13x21x3xui4>) -> tensor<13x21x3xi1> %2 = tosa.select %1, %arg0, 
%0: ( tensor<13x21x3xi1>, tensor<13x21x3xui4>, tensor<13x21x3xui4>) -> tensor<13x21x3xui4> @@ -1538,8 +1532,9 @@ func.func @canonicalize_select_to_clamp_ui4(%arg0: tensor<13x21x3xui4>) -> tenso // CHECK-LABEL: @canonicalize_select_to_clamp_i16_pat2 func.func @canonicalize_select_to_clamp_i16_pat2(%arg0: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> { -// CHECK: %[[VAL_1:.*]] = tosa.clamp %arg{{.*}} {max_fp = 0x7F800000 : f32, max_int = 3 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi16>) -> tensor<13x21x3xi16> -// CHECK: return %[[VAL_1]] : tensor<13x21x3xi16> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x3xi16>) -> tensor<13x21x3xi16> { +// CHECK: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 3 : i16, min_val = -32768 : i16} : (tensor<13x21x3xi16>) -> tensor<13x21x3xi16> +// CHECK: return [[VAR_0_]] : tensor<13x21x3xi16> %0 = "tosa.const"() <{value = dense<3> : tensor<13x21x3xi16>}>: () -> tensor<13x21x3xi16> %1 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi16>, tensor<13x21x3xi16>) -> tensor<13x21x3xi1> %2 = tosa.select %1, %0, %arg0: ( tensor<13x21x3xi1>, tensor<13x21x3xi16>, tensor<13x21x3xi16>) -> tensor<13x21x3xi16> @@ -1549,8 +1544,9 @@ func.func @canonicalize_select_to_clamp_i16_pat2(%arg0: tensor<13x21x3xi16>) -> // CHECK-LABEL: @canonicalize_select_to_clamp_i8_neg func.func @canonicalize_select_to_clamp_i8_neg(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { -// CHECK: %[[VAL_1:.*]] = tosa.clamp %arg{{.*}} {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = -42 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8> -// CHECK: return %[[VAL_1]] : tensor<13x21x3xi8> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { +// CHECK: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 127 : i8, min_val = -42 : i8} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8> +// CHECK: return [[VAR_0_]] : tensor<13x21x3xi8> %0 = "tosa.const"() <{value 
= dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8> %1 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1> %2 = tosa.select %1, %arg0, %0: ( tensor<13x21x3xi1>, tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi8> @@ -1561,8 +1557,9 @@ func.func @canonicalize_select_to_clamp_i8_neg(%arg0: tensor<13x21x3xi8>) -> ten // CHECK-LABEL: @canonicalize_select_to_clamp_f64_pat2_neg func.func @canonicalize_select_to_clamp_f64_pat2_neg(%arg0: tensor<13x21x3xf64>) -> tensor<13x21x3xf64> { -// CHECK: %[[VAL_1:.*]] = tosa.clamp %arg{{.*}} {max_fp = -3.500000e+00 : f64, max_int = 9223372036854775807 : i64, min_fp = 0xFFF0000000000000 : f64, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xf64>) -> tensor<13x21x3xf64> -// CHECK: return %[[VAL_1]] : tensor<13x21x3xf64> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x3xf64>) -> tensor<13x21x3xf64> { +// CHECK: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = -3.500000e+00 : f64, min_val = 0xFFF0000000000000 : f64} : (tensor<13x21x3xf64>) -> tensor<13x21x3xf64> +// CHECK: return [[VAR_0_]] : tensor<13x21x3xf64> %0 = "tosa.const"() <{value = dense<-3.5> : tensor<13x21x3xf64>}>: () -> tensor<13x21x3xf64> %1 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xf64>, tensor<13x21x3xf64>) -> tensor<13x21x3xi1> %2 = tosa.select %1, %0, %arg0: ( tensor<13x21x3xi1>, tensor<13x21x3xf64>, tensor<13x21x3xf64>) -> tensor<13x21x3xf64> @@ -1572,8 +1569,9 @@ func.func @canonicalize_select_to_clamp_f64_pat2_neg(%arg0: tensor<13x21x3xf64>) // CHECK-LABEL: @canonicalize_select_lrelu_zero_pattern func.func @canonicalize_select_lrelu_zero_pattern(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { -// CHECK: %[[VAL_1:.*]] = tosa.clamp %arg{{.*}} {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.000000e+00 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> -// CHECK: return %[[VAL_1]] : tensor<13x21x3xf32> +// CHECK-SAME: 
([[PARAM_0_:%.+]]: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { +// CHECK: [[VAR_0_:%.+]] = tosa.clamp [[PARAM_0_]] {max_val = 0x7F800000 : f32, min_val = 0.000000e+00 : f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: return [[VAR_0_]] : tensor<13x21x3xf32> %0 = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}>: () -> tensor<1x1x1xf32> %1 = tosa.mul %arg0, %0 {shift = 0 : i8}: (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xf32> %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<13x21x3xi1> @@ -1585,9 +1583,10 @@ func.func @canonicalize_select_lrelu_zero_pattern(%arg0: tensor<13x21x3xf32>) -> // CHECK-LABEL: @canonicalize_select_to_clamp_i64_and_i8_pat1 func.func @canonicalize_select_to_clamp_i64_and_i8_pat1(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { -// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8> -// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 42 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8> -// CHECK: return %[[VAL_2]] : tensor<13x21x3xi8> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x3xi64>, [[PARAM_1_:%.+]]: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { +// CHECK: [[VAR_0_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8> +// CHECK: [[VAR_1_:%.+]] = tosa.clamp [[VAR_0_]] {max_val = 127 : i8, min_val = 42 : i8} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8> +// CHECK: return [[VAR_1_]] : tensor<13x21x3xi8> %0 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64> %1 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8> %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi1> @@ -1599,9 +1598,10 @@ func.func @canonicalize_select_to_clamp_i64_and_i8_pat1(%arg0: 
tensor<13x21x3xi6 // CHECK-LABEL: @canonicalize_select_to_clamp_i64_and_i8_pat2 func.func @canonicalize_select_to_clamp_i64_and_i8_pat2(%arg0: tensor<13x21x3xi64>, %arg1: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { -// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8> -// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = -42 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8> -// CHECK: return %[[VAL_2]] : tensor<13x21x3xi8> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x3xi64>, [[PARAM_1_:%.+]]: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> { +// CHECK: [[VAR_0_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor<13x21x3xi64>) -> tensor<13x21x3xi8> +// CHECK: [[VAR_1_:%.+]] = tosa.clamp [[VAR_0_]] {max_val = -42 : i8, min_val = -128 : i8} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi8> +// CHECK: return [[VAR_1_]] : tensor<13x21x3xi8> %0 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64> %1 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8> %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi64>, tensor<13x21x3xi64>) -> tensor<13x21x3xi1> @@ -1613,9 +1613,10 @@ func.func @canonicalize_select_to_clamp_i64_and_i8_pat2(%arg0: tensor<13x21x3xi6 // CHECK-LABEL: @canonicalize_select_to_clamp_i8_and_i64_pat1 func.func @canonicalize_select_to_clamp_i8_and_i64_pat1(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { -// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64> -// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = 9223372036854775807 : i64, min_fp = 0xFF800000 : f32, min_int = 42 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64> -// CHECK: return %[[VAL_2]] : tensor<13x21x3xi64> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x3xi8>, [[PARAM_1_:%.+]]: tensor<13x21x3xi64>) 
-> tensor<13x21x3xi64> { +// CHECK: [[VAR_0_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64> +// CHECK: [[VAR_1_:%.+]] = tosa.clamp [[VAR_0_]] {max_val = 9223372036854775807 : i64, min_val = 42 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64> +// CHECK: return [[VAR_1_]] : tensor<13x21x3xi64> %0 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8> %1 = "tosa.const"() <{value = dense<42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64> %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1> @@ -1627,9 +1628,10 @@ func.func @canonicalize_select_to_clamp_i8_and_i64_pat1(%arg0: tensor<13x21x3xi8 // CHECK-LABEL: @canonicalize_select_to_clamp_i8_and_i64_pat2 func.func @canonicalize_select_to_clamp_i8_and_i64_pat2(%arg0: tensor<13x21x3xi8>, %arg1: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { -// CHECK: %[[VAL_1:.*]] = tosa.cast %arg{{.*}} : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64> -// CHECK: %[[VAL_2:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0x7F800000 : f32, max_int = -42 : i64, min_fp = 0xFF800000 : f32, min_int = -9223372036854775808 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64> -// CHECK: return %[[VAL_2]] : tensor<13x21x3xi64> +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<13x21x3xi8>, [[PARAM_1_:%.+]]: tensor<13x21x3xi64>) -> tensor<13x21x3xi64> { +// CHECK: [[VAR_0_:%.+]] = tosa.cast [[PARAM_0_]] : (tensor<13x21x3xi8>) -> tensor<13x21x3xi64> +// CHECK: [[VAR_1_:%.+]] = tosa.clamp [[VAR_0_]] {max_val = -42 : i64, min_val = -9223372036854775808 : i64} : (tensor<13x21x3xi64>) -> tensor<13x21x3xi64> +// CHECK: return [[VAR_1_]] : tensor<13x21x3xi64> %0 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi8>}>: () -> tensor<13x21x3xi8> %1 = "tosa.const"() <{value = dense<-42> : tensor<13x21x3xi64>}>: () -> tensor<13x21x3xi64> %2 = tosa.greater_equal %arg0, %0: (tensor<13x21x3xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi1> diff --git 
a/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir b/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir index 16af174e5f79e..d358e5f5403bf 100644 --- a/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir @@ -2,12 +2,12 @@ // CHECK-LABEL: @clamp_twice_is_single_clamp func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> { - // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64} {{.*}} loc(#[[FUSED:.*]]) + // CHECK: tosa.clamp %arg0 {max_val = 2 : i8, min_val = -2 : i8} {{.*}} loc(#[[FUSED:.*]]) // CHECK-DAG: #[[A:.*]] = loc("Clamp_A") // CHECK-DAG: #[[B:.*]] = loc("Clamp_B") // CHECK: #[[FUSED]] = loc(fused[#[[B]], #[[A]]]) - %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64} : (tensor<4xi8>) -> tensor<4xi8> loc(#loc0) - %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64} : (tensor<4xi8>) -> tensor<4xi8> loc(#loc1) + %0 = tosa.clamp %arg0 {max_val = 4 : i8, min_val = -2 : i8} : (tensor<4xi8>) -> tensor<4xi8> loc(#loc0) + %1 = tosa.clamp %0 {max_val = 2 : i8, min_val = -4 : i8} : (tensor<4xi8>) -> tensor<4xi8> loc(#loc1) return %1 : tensor<4xi8> } #loc0 = loc("Clamp_A") diff --git a/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir b/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir index 276e87405e695..58317535cd88c 100644 --- a/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir +++ b/mlir/test/Dialect/Tosa/constant-clamp-opt.mlir @@ -8,7 +8,7 @@ func.func @clamp_fold_integer() -> tensor<3xi16> { // CHECK-NOT: tosa.clamp // CHECK: return [[RES]] %0 = "tosa.const"() {value = dense<[-12, 0, 5]> : tensor<3xi16>} : () -> tensor<3xi16> - %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 1 : i64, min_fp = 0.0 : f32, min_int = -2 : i64} + %1 = "tosa.clamp"(%0) {max_val = 1 : i16, min_val = -2 : i16} : 
(tensor<3xi16>) -> tensor<3xi16> return %1 : tensor<3xi16> } @@ -19,21 +19,11 @@ func.func @clamp_fold_integer_equal_lower_upper() -> tensor<3xi8> { // CHECK-NOT: tosa.clamp // CHECK: return [[RES]] %0 = "tosa.const"() {value = dense<[2, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8> - %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 17 : i64, min_fp = 0.0 : f32, min_int = 17 : i64} + %1 = "tosa.clamp"(%0) {max_val = 17 : i8, min_val = 17 : i8} : (tensor<3xi8>) -> tensor<3xi8> return %1 : tensor<3xi8> } -// CHECK-LABEL: @clamp_fold_integer_maximum_larger_than_result_type -func.func @clamp_fold_integer_maximum_larger_than_result_type() -> tensor<3xi8> { - // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}9, 4, 4{{.*}}tensor<3xi8> - // CHECK-NOT: tosa.clamp - // CHECK: return [[RES]] - %0 = "tosa.const"() {value = dense<[9, 0, -5]> : tensor<3xi8>} : () -> tensor<3xi8> - %1 = "tosa.clamp"(%0) {max_fp = 0.00 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.0 : f32, min_int = 4 : i64} - : (tensor<3xi8>) -> tensor<3xi8> - return %1 : tensor<3xi8> -} // Float clamp @@ -43,7 +33,7 @@ func.func @clamp_fold_float() -> tensor<3xf16> { // CHECK-NOT: tosa.clamp // CHECK: return [[RES]] %0 = "tosa.const"() {value = dense<[-12.4, 0.9, 5.2]> : tensor<3xf16>} : () -> tensor<3xf16> - %1 = "tosa.clamp"(%0) {max_fp = 1.00 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64} + %1 = "tosa.clamp"(%0) {max_val = 1.00 : f16, min_val = -2.0 : f16} : (tensor<3xf16>) -> tensor<3xf16> return %1 : tensor<3xf16> } @@ -57,7 +47,7 @@ func.func @clamp_fold_float_infty_nan() -> tensor<5xf32> { dense<[0x7F800000, 0xFF800000, 0.0, -0.0, 0x7FC00000]> : tensor<5xf32> } : () -> tensor<5xf32> - %1 = "tosa.clamp"(%0) {max_fp = 1.00 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64} + %1 = "tosa.clamp"(%0) {max_val = 1.00 : f32, min_val = -2.0 : f32} : (tensor<5xf32>) -> tensor<5xf32> return %1 : tensor<5xf32> } @@ -71,21 +61,8 @@ func.func 
@clamp_fold_float_infinity_upper() -> tensor<5xf32> { dense<[0x7F800000, 0xFF800000, 9.0, -0.0, 0x7FC00000]> : tensor<5xf32> } : () -> tensor<5xf32> - %1 = "tosa.clamp"(%0) {max_fp = 0x7F800000 : f32, max_int = 1594 : i64, min_fp = -2.0 : f32, min_int = -17 : i64} + %1 = "tosa.clamp"(%0) {max_val = 0x7F800000 : f32, min_val = -2.0 : f32} : (tensor<5xf32>) -> tensor<5xf32> return %1 : tensor<5xf32> } -// CHECK-LABEL: @clamp_fold_float_maximum_larger_than_result_type -func.func @clamp_fold_float_maximum_larger_than_result_type() -> tensor<2xf16> { - // CHECK: [[RES:]] ={{.*}}tosa.const{{.*}}1.83{{[0-9]*}}e+01, -5.{{0*}}e-01 - // CHECK-NOT: tosa.clamp - // CHECK: return [[RES]] - %0 = "tosa.const"() {value = - dense<[18.32, -0.98747]> : - tensor<2xf16> - } : () -> tensor<2xf16> - %1 = "tosa.clamp"(%0) {max_fp = 3.4028234e+38 : f32, max_int = 1594 : i64, min_fp = -0.5 : f32, min_int = -17 : i64} - : (tensor<2xf16>) -> tensor<2xf16> - return %1 : tensor<2xf16> -} diff --git a/mlir/test/Dialect/Tosa/invalid.mlir b/mlir/test/Dialect/Tosa/invalid.mlir index 22b0861b4850c..3e8bc7cfde8da 100644 --- a/mlir/test/Dialect/Tosa/invalid.mlir +++ b/mlir/test/Dialect/Tosa/invalid.mlir @@ -855,7 +855,7 @@ func.func @test_unsupported_int64_data_type(%arg0: tensor<1x13x13x5xf32>) -> ten // CHECK-LABEL: test_mismatch_in_out_data_type_clamp func.func @test_mismatch_in_out_data_type_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf16> { // expected-error@+1 {{'tosa.clamp' op requires the same element type for all operands and results}} - %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf16> + %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf16> return %0 : tensor<13x21x3xf16> } @@ -864,7 +864,7 @@ func.func @test_mismatch_in_out_data_type_clamp(%arg0: tensor<13x21x3xf32>) -> t // CHECK-LABEL: test_mismatch_in_out_shape_clamp func.func 
@test_mismatch_in_out_shape_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x1xf32> { // expected-error@+1 {{'tosa.clamp' op requires the same shape for all operands and results}} - %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x1xf32> + %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x1xf32> return %0 : tensor<13x21x1xf32> } diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir index 4333cf60cdbe7..bbbcb735e613e 100644 --- a/mlir/test/Dialect/Tosa/ops.mlir +++ b/mlir/test/Dialect/Tosa/ops.mlir @@ -193,42 +193,42 @@ func.func @test_transpose_conv2d_with_local_bound(%arg0: tensor<1x32x32x8xf32>, // ----- // CHECK-LABEL: clamp func.func @test_clamp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0: f32} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } // ----- // CHECK-LABEL: clamp_propagate func.func @test_clamp_propagate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "PROPAGATE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0: f32, nan_mode = "PROPAGATE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } // ----- // CHECK-LABEL: clamp_ignore func.func @test_clamp_ignore(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { - %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64, nan_mode = "IGNORE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + %0 = tosa.clamp %arg0 {min_val = 0.0 : f32, max_val = 1.0: 
f32, nan_mode = "IGNORE"} : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> return %0 : tensor<13x21x3xf32> } // ----- // CHECK-LABEL: clamp_f16 func.func @test_clamp_f16(%arg0: tensor<13x21x3xf16>) -> tensor<13x21x3xf16> { - %0 = tosa.clamp %arg0 {min_fp = 0.0 : f16, max_fp = 1.0: f16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xf16>) -> tensor<13x21x3xf16> + %0 = tosa.clamp %arg0 {min_val = 0.0 : f16, max_val = 1.0: f16} : (tensor<13x21x3xf16>) -> tensor<13x21x3xf16> return %0 : tensor<13x21x3xf16> } // ----- // CHECK-LABEL: clamp_bf16 func.func @test_clamp_bf16(%arg0: tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> { - %0 = tosa.clamp %arg0 {min_fp = 0.0 : bf16, max_fp = 1.0: bf16, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> + %0 = tosa.clamp %arg0 {min_val = 0.0 : bf16, max_val = 1.0: bf16} : (tensor<13x21x3xbf16>) -> tensor<13x21x3xbf16> return %0 : tensor<13x21x3xbf16> } // ----- // CHECK-LABEL: clamp_quantized func.func @test_clamp_quantized(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { - %0 = tosa.clamp %arg0 {min_fp = 0.0 : f32, max_fp = 1.0: f32, min_int = 0 : i64, max_int = 1 : i64} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + %0 = tosa.clamp %arg0 {min_val = 0 : i8, max_val = 1 : i8} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> return %0 : tensor<13x21x3x!quant.uniform> } diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir index 59b4f2bf84f7f..b8988c2b5728b 100644 --- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir @@ -34,7 +34,7 @@ func.func @test_unary_f32(%arg0 : tensor<4xf32>) -> () { %1 = tosa.ceil %arg0 : (tensor<4xf32>) -> tensor<*xf32> // CHECK: tosa.clamp %arg0 {{.+}} : (tensor<4xf32>) -> tensor<4xf32> - %2 = tosa.clamp %arg0 { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : 
(tensor<4xf32>) -> tensor<*xf32> + %2 = tosa.clamp %arg0 { min_val = 0.0 : f32, max_val = 10.0 : f32 } : (tensor<4xf32>) -> tensor<*xf32> // CHECK: tosa.exp %arg0 : (tensor<4xf32>) -> tensor<4xf32> %3 = tosa.exp %arg0 : (tensor<4xf32>) -> tensor<*xf32> @@ -82,7 +82,7 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () { %1 = tosa.bitwise_not %arg0 : (tensor<4xi32>) -> tensor<*xi32> // CHECK: tosa.clamp %arg0 {{.+}} : (tensor<4xi32>) -> tensor<4xi32> - %2 = tosa.clamp %arg0 { max_int = 10 : i64, min_int = 0 : i64, min_fp = 0.0 : f32, max_fp = 10.0 : f32 } : (tensor<4xi32>) -> tensor<*xi32> + %2 = tosa.clamp %arg0 { max_val = 10 : i32, min_val = 0 : i32} : (tensor<4xi32>) -> tensor<*xi32> // CHECK: tosa.clz %arg0 : (tensor<4xi32>) -> tensor<4xi32> %3 = tosa.clz %arg0 : (tensor<4xi32>) -> tensor<*xi32> diff --git a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir index 947335e45a9d9..e70f3644da646 100644 --- a/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir +++ b/mlir/test/Dialect/Tosa/tosa-reduce-transposes.mlir @@ -22,7 +22,7 @@ func.func @test_transpose_tracks_to_nullifying_single_step(%arg0: tensor<1x2x3x4 func.func @test_transpose_tracks_to_nullifying_multi_unary_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> { %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32> - %clamp = tosa.clamp %0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> + %clamp = tosa.clamp %0 {max_val = 1 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> 
tensor<4xi32> @@ -41,7 +41,7 @@ func.func @test_transpose_tracks_to_nullifying_diverging_binary(%arg0: tensor<1x %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32> %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32> - %clamp = tosa.clamp %transpose0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> + %clamp = tosa.clamp %transpose0 {max_val = 1 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> @@ -61,7 +61,7 @@ func.func @test_transpose_tracks_to_nullifying_diverging_binary_with_broadcastin %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32> %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x1x4xi32>, tensor<4xi32>) -> tensor<1x1x4x2xi32> - %clamp = tosa.clamp %transpose0 {max_fp = 1.0 : f32, min_fp = 0.0 : f32, max_int = 1 : i64, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> + %clamp = tosa.clamp %transpose0 {max_val = 1 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %abs = tosa.abs %transpose1 : (tensor<1x1x4x2xi32>) -> tensor<1x1x4x2xi32> %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x1x4x2xi32>) -> tensor<1x3x4x2xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> @@ -212,7 +212,7 @@ func.func @test_reshape_for_broadcast(%arg0: tensor<4x3x2xi32>) -> 
tensor<4x3x2x // CHECK-DAG: %[[VAL_19:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 64]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_20:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_19]] : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x1x1x64xf32> // CHECK-DAG: %[[VAL_21:.*]] = tosa.add %[[VAL_18]], %[[VAL_20]] : (tensor<1x112x112x64xf32>, tensor<1x1x1x64xf32>) -> tensor<1x112x112x64xf32> -// CHECK-DAG: %[[VAL_22:.*]] = tosa.clamp %[[VAL_21]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> +// CHECK-DAG: %[[VAL_22:.*]] = tosa.clamp %[[VAL_21]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32>, %74: tensor<1x112x112x64xf32>) -> tensor<1x112x112x64xf32> { %58 = tosa.const_shape {value = dense<[1, 64, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> %59 = "tosa.const"() <{value = dense_resource : tensor<64xf32>}> : () -> tensor<64xf32> @@ -233,12 +233,11 @@ func.func @test_resnet18_common_case(%arg0: tensor<64xf32>, %arg1: tensor<64xf32 %84 = tosa.mul %82, %83 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> %85 = tosa.reshape %59, %58 : (tensor<64xf32>, !tosa.shape<4>) -> tensor<1x64x1x1xf32> %86 = tosa.add %84, %85 : (tensor<1x64x112x112xf32>, tensor<1x64x1x1xf32>) -> tensor<1x64x112x112xf32> - %87 = tosa.clamp %86 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32> + %87 = tosa.clamp %86 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x112x112xf32>) -> tensor<1x64x112x112xf32> %88 = tosa.transpose %87, %63 : (tensor<1x64x112x112xf32>, tensor<4xi32>) -> tensor<1x112x112x64xf32> return %88 : tensor<1x112x112x64xf32> } - 
// ----- // CHECK-LABEL: @test_back_to_back_nullifiers @@ -280,7 +279,7 @@ func.func @test_back_to_back_nullifiers_different_transposes(%arg0: tensor<2x3x4 func.func @test_no_transform_if_outside_fan_in_cone(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) { %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> - %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32> + %clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> return %1, %clamp : tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32> @@ -296,7 +295,7 @@ func.func @test_two_different_downstream_converge_to_reshape_same_perms(%arg0: t %0 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> %shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> %1 = tosa.reshape %arg0, %shape : (tensor<64xf32>, !tosa.shape<3>) -> tensor<1x64x1xf32> - %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32> + %2 = tosa.clamp %1 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32> %3 = tosa.transpose %1, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32> %4 = tosa.transpose %2, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32> return %3, %4 : tensor<1x1x64xf32>, tensor<1x1x64xf32> @@ -317,7 +316,7 @@ func.func 
@test_two_different_downstream_converge_to_reshape_different_perms(%ar %1 = "tosa.const"() <{value = dense<[0, 2, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> %shape = tosa.const_shape {value = dense<[1, 64, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> %2 = tosa.reshape %arg0, %shape : (tensor<64xf32>, !tosa.shape<3>) -> tensor<1x64x1xf32> - %3 = tosa.clamp %2 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32> + %3 = tosa.clamp %2 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<1x64x1xf32>) -> tensor<1x64x1xf32> %4 = tosa.transpose %2, %1 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<1x1x64xf32> %5 = tosa.transpose %3, %0 : (tensor<1x64x1xf32>, tensor<3xi32>) -> tensor<64x1x1xf32> return %4, %5 : tensor<1x1x64xf32>, tensor<64x1x1xf32> @@ -335,7 +334,7 @@ func.func @test_two_different_downstream_converge_to_reshape_different_perms(%ar // CHECK: return %[[RES1]], %[[RES2]] func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: tensor<3x2xf32>) -> (tensor<2x3xf32>, tensor<3x2xf32>) { %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32> %1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x2xf32>) -> tensor<3x2xf32> + %2 = tosa.clamp %1 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<3x2xf32>) -> tensor<3x2xf32> %3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32> %4 = tosa.add %arg1, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> return %3, %4: tensor<2x3xf32>, tensor<3x2xf32> @@ -352,7 +351,7 @@ func.func @test_outside_perms_usage_of_fan_in(%arg0: tensor<2x3xf32>, %arg1: ten func.func @test_use_present_in_another_valid_perms_fan_in(%arg0: tensor<2x3xf32>, 
%arg1: tensor<2x3xf32>) -> (tensor<2x3xf32>, tensor<2x3xf32>) { %0 = "tosa.const"() <{value = dense<[1, 0]> : tensor<2xi32>}> : () -> tensor<2xi32> %1 = tosa.transpose %arg0, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> - %2 = tosa.clamp %1 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x2xf32>) -> tensor<3x2xf32> + %2 = tosa.clamp %1 {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} : (tensor<3x2xf32>) -> tensor<3x2xf32> %3 = tosa.transpose %2, %0 : (tensor<3x2xf32>, tensor<2xi32>) -> tensor<2x3xf32> %4 = tosa.transpose %arg1, %0 : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> %5 = tosa.add %4, %2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> @@ -389,7 +388,7 @@ func.func @test_two_same_perms_fan_in_but_one_doesnt_convert_dependents(%arg0: t func.func @test_direct_use_in_other_transpose_with_same_perms(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) { %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> - %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32> + %clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> %2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> @@ -413,12 +412,12 @@ func.func @test_const_transpose() -> tensor<2x3xi32> { // CHECK-LABEL: @test_transpose_tracks_to_const_single_step // CHECK: %[[NEW_CONST:.*]] = "tosa.const"() <{value = dense<0> : 
tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32> -// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %[[NEW_CONST]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> +// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %[[NEW_CONST]] {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> // CHECK-NOT: tosa.transpose // CHECK: return %[[NEW_CLAMP]] func.func @test_transpose_tracks_to_const_single_step() -> tensor<1x2x3x4xi32> { %0 = "tosa.const"() {value = dense<0> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32> - %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> + %clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> %1 = tosa.transpose %clamp, %perms1 : (tensor<1x3x4x2xi32>, tensor<4xi32>) -> tensor<1x2x3x4xi32> return %1 : tensor<1x2x3x4xi32> @@ -428,14 +427,14 @@ func.func @test_transpose_tracks_to_const_single_step() -> tensor<1x2x3x4xi32> { // CHECK-LABEL: @test_static_unary_path_to_const // CHECK: %[[NEW_CONST:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32> -// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %[[NEW_CONST]] {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> +// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %[[NEW_CONST]] {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> // CHECK: %[[NEW_ABS:.*]] = tosa.abs %[[NEW_CLAMP]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> // CHECK: %[[NEW_NOT:.*]] = tosa.bitwise_not %[[NEW_ABS]] : (tensor<1x2x3x4xi32>) -> 
tensor<1x2x3x4xi32> // CHECK: return %[[NEW_NOT]] func.func @test_static_unary_path_to_const() -> tensor<1x2x3x4xi32> { %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = "tosa.const"() {value = dense<1> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32> - %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> + %clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %abs = tosa.abs %clamp : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %bitwise_not = tosa.bitwise_not %abs : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> @@ -449,7 +448,7 @@ func.func @test_static_unary_path_to_const() -> tensor<1x2x3x4xi32> { // CHECK: %[[NEW_CONST:.*]] = "tosa.const"() // CHECK-SAME{LITERAL}: dense<[[[[1, 3, 5, 7], [9, 11, 13, 15], [17, 19, 21, 23]], [[2, 4, 6, 8], [10, 12, 14, 16], [18, 20, 22, 24]]]]> // CHECK: tensor<1x2x3x4xi32>}> : () -> tensor<1x2x3x4xi32> -// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %arg0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> +// CHECK: %[[NEW_CLAMP:.*]] = tosa.clamp %arg0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> // CHECK: %[[NEW_ABS:.*]] = tosa.abs %[[NEW_CONST]] : (tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> // CHECK: %[[NEW_ADD:.*]] = tosa.add %[[NEW_ABS]], %[[NEW_CLAMP]] : (tensor<1x2x3x4xi32>, tensor<1x2x3x4xi32>) -> tensor<1x2x3x4xi32> // CHECK: return %[[NEW_ADD]] @@ -459,7 +458,7 @@ func.func @test_static_diverges_to_non_splat_const_and_nullifying(%arg0: tensor< %const = "tosa.const"() {value = dense<[[[[1, 2], [3, 4], [5, 6], [7, 8]], [[9, 10], [11, 12], [13, 14], 
[15, 16]], [[17, 18], [19, 20], [21, 22], [23, 24]]]]> : tensor<1x3x4x2xi32>} : () -> tensor<1x3x4x2xi32> - %clamp = tosa.clamp %transpose0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> + %clamp = tosa.clamp %transpose0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %abs = tosa.abs %const : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %add = tosa.add %abs, %clamp : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> @@ -475,7 +474,7 @@ func.func @test_static_diverges_to_non_splat_const_and_nullifying(%arg0: tensor< func.func @test_multi_downstream_both_nullify(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) { %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> - %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32> + %clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> %2 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> @@ -495,7 +494,7 @@ func.func @test_multi_downstream_both_nullify(%arg0: tensor<3x3x3x3xi32>) -> (te func.func @test_multi_downstream_one_nullifies_upstream_other_does_not(%arg0: tensor<3x3x3x3xi32>) -> (tensor<3x3x3x3xi32>, tensor<3x3x3x3xi32>) { %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : 
tensor<4xi32>} : () -> tensor<4xi32> %0 = tosa.transpose %arg0, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> - %clamp = tosa.clamp %0 {max_fp = 3.40282347E+38 : f32, max_int = 2147483647 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32> + %clamp = tosa.clamp %0 {max_val = 2147483647 : i32, min_val = 0 : i32} : (tensor<3x3x3x3xi32>) -> tensor<3x3x3x3xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> %1 = tosa.transpose %clamp, %perms1 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> %2 = tosa.transpose %clamp, %perms0 : (tensor<3x3x3x3xi32>, tensor<4xi32>) -> tensor<3x3x3x3xi32> @@ -536,7 +535,7 @@ func.func @test_transpose_tracks_to_nullifying_diverging_binary_unknown_dim_repl %perms0 = "tosa.const"() {value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x?x3x4xi32>, tensor<4xi32>) -> tensor<?x3x4x2xi32> %transpose1 = tosa.transpose %arg1, %perms0 : (tensor<1x2x?x4xi32>, tensor<4xi32>) -> tensor<1x?x?x2xi32> - %clamp = tosa.clamp %transpose0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<?x3x4x2xi32>) -> tensor<?x3x4x2xi32> + %clamp = tosa.clamp %transpose0 {min_val = 0 : i32, max_val = 1 : i32} : (tensor<?x3x4x2xi32>) -> tensor<?x3x4x2xi32> %abs = tosa.abs %transpose1 : (tensor<1x?x?x2xi32>) -> tensor<1x?x?x2xi32> %add = tosa.add %clamp, %abs : (tensor<?x3x4x2xi32>, tensor<1x?x?x2xi32>) -> tensor<1x3x4x2xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> @@ -571,7 +570,7 @@ func.func @test_unimplemented_non_const_perms(%perms: tensor<2xi32>) -> tensor func.func @test_unimplemented_transpose_tracks_to_non_nullifying_transpose_single_step(%arg0: tensor<1x2x3x4xi32>) -> tensor<1x2x4x3xi32> { %perms0 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x4x3x2xi32> - %clamp = tosa.clamp %0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64,
max_fp = 1.0 : f64} : (tensor<1x4x3x2xi32>) -> tensor<1x4x3x2xi32> + %clamp = tosa.clamp %0 {min_val = 0 : i32, max_val = 1 : i32} : (tensor<1x4x3x2xi32>) -> tensor<1x4x3x2xi32> %perms1 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> %1 = tosa.transpose %clamp, %perms1 : (tensor<1x4x3x2xi32>, tensor<4xi32>) -> tensor<1x2x4x3xi32> return %1 : tensor<1x2x4x3xi32> @@ -653,7 +652,7 @@ func.func @test_unimplemented_static_diverges_to_one_nullifying_one_non_nullifyi %perms1 = "tosa.const"() {value = dense<[0, 3, 2, 1]> : tensor<4xi32>} : () -> tensor<4xi32> %transpose0 = tosa.transpose %arg0, %perms0 : (tensor<1x2x3x4xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32> %transpose1 = tosa.transpose %arg1, %perms1 : (tensor<1x2x4x3xi32>, tensor<4xi32>) -> tensor<1x3x4x2xi32> - %clamp = tosa.clamp %transpose0 {min_int = 0 : i64, max_int = 1 : i64, min_fp = 0.0 : f64, max_fp = 1.0 : f64} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> + %clamp = tosa.clamp %transpose0 {min_val = 0 : i32, max_val = 1 : i32} : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %abs = tosa.abs %transpose1 : (tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %add = tosa.add %clamp, %abs : (tensor<1x3x4x2xi32>, tensor<1x3x4x2xi32>) -> tensor<1x3x4x2xi32> %perms2 = "tosa.const"() {value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32>