Skip to content

[AutoBump] Merge with fixes ab93bd69 (Feb 11) (43) #587

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: bump_to_9387fd96
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,8 @@ def Tosa_ClampOp : Tosa_ElementwiseUnaryOp<"clamp"> {

let arguments = (ins
Tosa_Tensor:$input,
I64Attr:$min_int,
I64Attr:$max_int,
Tosa_FloatAttr:$min_fp,
Tosa_FloatAttr:$max_fp,
Tosa_IntOrFloatAttr:$min_val,
Tosa_IntOrFloatAttr:$max_val,
DefaultValuedAttr<Tosa_NanPropagationAttr, "\"PROPAGATE\"">:$nan_mode
);

Expand Down
8 changes: 8 additions & 0 deletions mlir/include/mlir/Dialect/Tosa/IR/TosaTypesBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,14 @@ def Tosa_FloatAttr : Attr<CPred<"::llvm::isa<::mlir::FloatAttr>($_self)">,
let returnType = [{ ::mlir::APFloat }];
}

def Tosa_IntegerAttr : Attr<CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">,
"arbitrary integer attribute"> {
let storageType = [{ ::mlir::IntegerAttr }];
let returnType = [{ ::llvm::APInt }];
}

def Tosa_IntOrFloatAttr : AnyAttrOf<[Tosa_IntegerAttr, Tosa_FloatAttr]>;

//===----------------------------------------------------------------------===//
// Iterable attributes.
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 4 additions & 4 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -407,8 +407,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
// tosa::ClampOp
if (isa<tosa::ClampOp>(op) && isa<FloatType>(elementTy)) {
bool losesInfo = false;
APFloat minApf = cast<FloatAttr>(op->getAttr("min_fp")).getValue();
APFloat maxApf = cast<FloatAttr>(op->getAttr("max_fp")).getValue();
APFloat minApf = cast<FloatAttr>(op->getAttr("min_val")).getValue();
APFloat maxApf = cast<FloatAttr>(op->getAttr("max_val")).getValue();
minApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
APFloat::rmNearestTiesToEven, &losesInfo);
maxApf.convert(cast<FloatType>(elementTy).getFloatSemantics(),
Expand All @@ -423,9 +423,9 @@ static Value createLinalgBodyCalculationForElementwiseOp(
if (isa<tosa::ClampOp>(op) && isa<IntegerType>(elementTy)) {
auto intTy = cast<IntegerType>(elementTy);
int64_t min =
cast<IntegerAttr>(op->getAttr("min_int")).getValue().getSExtValue();
cast<IntegerAttr>(op->getAttr("min_val")).getValue().getSExtValue();
int64_t max =
cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();
cast<IntegerAttr>(op->getAttr("max_val")).getValue().getSExtValue();

int64_t minRepresentable = std::numeric_limits<int64_t>::min();
int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
Expand Down
237 changes: 155 additions & 82 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,9 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
DenseElementsAttr onFalseAttr;
DenseElementsAttr onTrueAttr;

const Type resultElemTy = op.getType().getElementType();
const bool resultElemTyIsUnsignedInteger = resultElemTy.isUnsignedInteger();

// Case one:
// %0 = tosa.greater_equal(input, cmp)
// %1 = tosa.select(%0, input, cmp)
Expand All @@ -328,10 +331,8 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
}

const auto inputElementType = geqIn2Attr.getElementType();
int64_t clampIntMin = std::numeric_limits<int64_t>::min();
int64_t clampIntMax = std::numeric_limits<int64_t>::max();
FloatAttr clampFloatMin;
FloatAttr clampFloatMax;
Attribute clampMin;
Attribute clampMax;
if (auto integerType = dyn_cast<IntegerType>(inputElementType)) {
int64_t splatValue;
if (integerType.isUnsigned()) {
Expand All @@ -343,26 +344,33 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
} else {
splatValue = geqIn2Attr.getSplatValue<APInt>().getSExtValue();
}
clampFloatMin =
rewriter.getF32FloatAttr(-std::numeric_limits<float>::infinity());
clampFloatMax =
rewriter.getF32FloatAttr(std::numeric_limits<float>::infinity());
if (isCaseOne) {
clampIntMin = splatValue;
clampMin = rewriter.getIntegerAttr(resultElemTy, splatValue);
clampMax = rewriter.getIntegerAttr(
resultElemTy,
resultElemTyIsUnsignedInteger
? APInt::getMaxValue(resultElemTy.getIntOrFloatBitWidth())
: APInt::getSignedMaxValue(
resultElemTy.getIntOrFloatBitWidth()));
} else {
clampIntMax = splatValue;
clampMax = rewriter.getIntegerAttr(resultElemTy, splatValue);
clampMin = rewriter.getIntegerAttr(
resultElemTy,
resultElemTyIsUnsignedInteger
? APInt::getMinValue(resultElemTy.getIntOrFloatBitWidth())
: APInt::getSignedMinValue(
resultElemTy.getIntOrFloatBitWidth()));
}
} else if (isa<FloatType>(inputElementType)) {
auto splatValue = geqIn2Attr.getSplatValue<APFloat>();
if (isCaseOne) {
clampFloatMin = rewriter.getFloatAttr(inputElementType, splatValue);
clampFloatMax = rewriter.getFloatAttr(
inputElementType,
APFloat::getInf(splatValue.getSemantics(), false));
clampMin = rewriter.getFloatAttr(resultElemTy, splatValue);
clampMax = rewriter.getFloatAttr(
resultElemTy, APFloat::getInf(splatValue.getSemantics(), false));
} else {
clampFloatMin = rewriter.getFloatAttr(
inputElementType, APFloat::getInf(splatValue.getSemantics(), true));
clampFloatMax = rewriter.getFloatAttr(inputElementType, splatValue);
clampMin = rewriter.getFloatAttr(
resultElemTy, APFloat::getInf(splatValue.getSemantics(), true));
clampMax = rewriter.getFloatAttr(resultElemTy, splatValue);
}
}

Expand All @@ -380,9 +388,8 @@ struct SelectToClampOptimization : public OpRewritePattern<tosa::SelectOp> {
input);
}

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), input, rewriter.getI64IntegerAttr(clampIntMin),
rewriter.getI64IntegerAttr(clampIntMax), clampFloatMin, clampFloatMax);
rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, op.getType(), input,
clampMin, clampMax);

return success();
}
Expand Down Expand Up @@ -606,10 +613,12 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {

if (isa<FloatType>(inputElementType)) {
// Unlike integer types, floating point types can represent infinity.
auto minClamp = op.getMinFp();
auto maxClamp = op.getMaxFp();
bool isMin = minClamp.isInfinity() && minClamp.isNegative();
bool isMax = maxClamp.isInfinity() && !maxClamp.isNegative();
auto minClamp =
llvm::cast<mlir::FloatAttr>(op.getMinValAttr()).getValue();
auto maxClamp =
llvm::cast<mlir::FloatAttr>(op.getMaxValAttr()).getValue();
bool isMin = minClamp.isNegInfinity();
bool isMax = maxClamp.isInfinity();

if (isMin && isMax) {
rewriter.replaceOp(op, input);
Expand All @@ -619,8 +628,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}

if (inputElementType.isUnsignedInteger()) {
int64_t minClamp = op.getMinInt();
int64_t maxClamp = op.getMaxInt();
int64_t minClamp =
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getUInt();
int64_t maxClamp =
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getUInt();

int64_t intMin =
APInt::getMinValue(inputElementType.getIntOrFloatBitWidth())
Expand All @@ -637,8 +648,10 @@ struct ClampIsNoOp : public OpRewritePattern<tosa::ClampOp> {
}

if (llvm::isa<IntegerType>(inputElementType)) {
int64_t minClamp = op.getMinInt();
int64_t maxClamp = op.getMaxInt();
int64_t minClamp =
llvm::cast<mlir::IntegerAttr>(op.getMinValAttr()).getInt();
int64_t maxClamp =
llvm::cast<mlir::IntegerAttr>(op.getMaxValAttr()).getInt();

int64_t intMin =
APInt::getSignedMinValue(inputElementType.getIntOrFloatBitWidth())
Expand Down Expand Up @@ -693,9 +706,10 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {

LogicalResult matchAndRewrite(tosa::ClampOp op,
PatternRewriter &rewriter) const override {
Value input = op.getInput();

// Check the input to the CLAMP op is itself a CLAMP.
auto clampOp =
dyn_cast_if_present<tosa::ClampOp>(op.getInput().getDefiningOp());
auto clampOp = dyn_cast_if_present<tosa::ClampOp>(input.getDefiningOp());
if (!clampOp)
return failure();

Expand All @@ -705,34 +719,87 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
if (opNanMode == "IGNORE" && clampNanMode == "PROPAGATE")
return failure();

// Check we have intersecting ranges.
const auto opMinInt = op.getMinInt();
const auto opMaxInt = op.getMaxInt();
const auto clampOpMinInt = clampOp.getMinInt();
const auto clampOpMaxInt = clampOp.getMaxInt();
ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt, clampOpMaxInt);
if (!opRangeIntRange.intersects(clampRangeIntRange))
return failure();
auto maxValAttr = op.getMaxValAttr();
auto minValAttr = op.getMinValAttr();
auto clampOpMaxValAttr = clampOp.getMaxValAttr();
auto clampOpMinValAttr = clampOp.getMinValAttr();

const auto opMinFloat = op.getMinFp();
const auto opMaxFloat = op.getMaxFp();
const auto clampOpMinFloat = clampOp.getMinFp();
const auto clampOpMaxFloat = clampOp.getMaxFp();
ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat, clampOpMaxFloat);
if (!opRangeFloatRange.intersects(clampRangeFloatRange))
return failure();
auto inputEType = llvm::cast<ShapedType>(input.getType()).getElementType();
if (auto quantType =
llvm::dyn_cast<mlir::quant::UniformQuantizedType>(inputEType)) {
inputEType = quantType.getStorageType();
}

Attribute newMinValAttr, newMaxValAttr;
if (mlir::isa<FloatType>(inputEType)) {
auto floatMaxValAttr = cast<mlir::FloatAttr>(maxValAttr);
auto floatMinValAttr = cast<mlir::FloatAttr>(minValAttr);
auto clampOpFloatMaxValAttr = cast<mlir::FloatAttr>(clampOpMaxValAttr);
auto clampOpFloatMinValAttr = cast<mlir::FloatAttr>(clampOpMinValAttr);

// Check we have intersecting ranges.
const auto opMinFloat = floatMinValAttr.getValue();
const auto opMaxFloat = floatMaxValAttr.getValue();
const auto clampOpMinFloat = clampOpFloatMinValAttr.getValue();
const auto clampOpMaxFloat = clampOpFloatMaxValAttr.getValue();
ClampRange<APFloat> opRangeFloatRange(opMinFloat, opMaxFloat);
ClampRange<APFloat> clampRangeFloatRange(clampOpMinFloat,
clampOpMaxFloat);
if (!opRangeFloatRange.intersects(clampRangeFloatRange))
return failure();

// Run the transformation.
auto newMinVal = std::max(opMinFloat, clampOpMinFloat);
auto newMaxVal = std::min(opMaxFloat, clampOpMaxFloat);
newMinValAttr = rewriter.getFloatAttr(inputEType, newMinVal);
newMaxValAttr = rewriter.getFloatAttr(inputEType, newMaxVal);
} else {
assert(mlir::isa<IntegerType>(inputEType));
auto intMaxValAttr = cast<mlir::IntegerAttr>(maxValAttr);
auto intMinValAttr = cast<mlir::IntegerAttr>(minValAttr);
auto clampOpIntMaxValAttr = cast<mlir::IntegerAttr>(clampOpMaxValAttr);
auto clampOpIntMinValAttr = cast<mlir::IntegerAttr>(clampOpMinValAttr);

if (inputEType.isUnsignedInteger()) {
// Check we have intersecting ranges.
const auto opMinInt = intMinValAttr.getUInt();
const auto opMaxInt = intMaxValAttr.getUInt();
const auto clampOpMinInt = clampOpIntMinValAttr.getUInt();
const auto clampOpMaxInt = clampOpIntMaxValAttr.getUInt();
ClampRange<std::uint64_t> opRangeIntRange(opMinInt, opMaxInt);
ClampRange<std::uint64_t> clampRangeIntRange(clampOpMinInt,
clampOpMaxInt);
if (!opRangeIntRange.intersects(clampRangeIntRange))
return failure();

// Run the transformation.
auto newMinVal = std::max(opMinInt, clampOpMinInt);
auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
} else {
// Check we have intersecting ranges.
const auto opMinInt = intMinValAttr.getInt();
const auto opMaxInt = intMaxValAttr.getInt();
const auto clampOpMinInt = clampOpIntMinValAttr.getInt();
const auto clampOpMaxInt = clampOpIntMaxValAttr.getInt();
ClampRange<std::int64_t> opRangeIntRange(opMinInt, opMaxInt);
ClampRange<std::int64_t> clampRangeIntRange(clampOpMinInt,
clampOpMaxInt);
if (!opRangeIntRange.intersects(clampRangeIntRange))
return failure();

// Run the transformation.
auto newMinVal = std::max(opMinInt, clampOpMinInt);
auto newMaxVal = std::min(opMaxInt, clampOpMaxInt);
newMinValAttr = rewriter.getIntegerAttr(inputEType, newMinVal);
newMaxValAttr = rewriter.getIntegerAttr(inputEType, newMaxVal);
}
}

// Run the transformation.
const auto minFp = std::max(opMinFloat, clampOpMinFloat).convertToFloat();
const auto maxFp = std::min(opMaxFloat, clampOpMaxFloat).convertToFloat();
const auto minInt = std::max(opMinInt, clampOpMinInt);
const auto maxInt = std::min(opMaxInt, clampOpMaxInt);
rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, {op->getLoc(), clampOp->getLoc()}, op.getType(), clampOp.getInput(),
rewriter.getI64IntegerAttr(minInt), rewriter.getI64IntegerAttr(maxInt),
rewriter.getF32FloatAttr(minFp), rewriter.getF32FloatAttr(maxFp),
newMinValAttr, newMaxValAttr,
rewriter.getStringAttr((opNanMode != clampNanMode) ? "IGNORE"
: opNanMode));
return success();
Expand Down Expand Up @@ -973,25 +1040,28 @@ struct MinToClampOptimization : public OpRewritePattern<tosa::MinimumOp> {
Value input = op.getInput1();
auto elementTy = llvm::cast<ShapedType>(input.getType()).getElementType();

int64_t minInt = std::numeric_limits<int32_t>::min();
float minFp = std::numeric_limits<float>::lowest();

int64_t maxInt;
float maxFp;
if (isa<FloatType>(elementTy)) {
Attribute minAttr;
Attribute maxAttr;
if (auto floatTy = dyn_cast<FloatType>(elementTy)) {
auto constMin = constant.getSplatValue<llvm::APFloat>();
maxFp = constMin.convertToFloat();
maxInt = constMin.convertToFloat();
maxAttr = rewriter.getFloatAttr(floatTy, constMin);
minAttr = rewriter.getFloatAttr(
floatTy, APFloat::getInf(constMin.getSemantics(), /*Negative=*/true));
} else if (auto intTy = cast<IntegerType>(elementTy);
intTy.isUnsignedInteger()) {
auto constMin = constant.getSplatValue<llvm::APInt>();
maxAttr = rewriter.getIntegerAttr(intTy, constMin);
minAttr =
rewriter.getIntegerAttr(intTy, APInt::getMinValue(intTy.getWidth()));
} else {
auto constMin = constant.getSplatValue<llvm::APInt>();
maxFp = constMin.getSExtValue();
maxInt = constMin.getSExtValue();
maxAttr = rewriter.getIntegerAttr(intTy, constMin);
minAttr = rewriter.getIntegerAttr(
intTy, APInt::getSignedMinValue(intTy.getWidth()));
}

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), input, rewriter.getI64IntegerAttr(minInt),
rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
rewriter.getF32FloatAttr(maxFp));
rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, op.getType(), input, minAttr,
maxAttr);

return success();
}
Expand All @@ -1016,25 +1086,28 @@ struct MaxToClampOptimization : public OpRewritePattern<tosa::MaximumOp> {
Value input = op.getInput1();
auto elementTy = llvm::cast<ShapedType>(input.getType()).getElementType();

int64_t maxInt = std::numeric_limits<int64_t>::max();
float maxFp = std::numeric_limits<float>::max();

int64_t minInt;
float minFp;
if (isa<FloatType>(elementTy)) {
Attribute minAttr;
Attribute maxAttr;
if (auto floatTy = dyn_cast<FloatType>(elementTy)) {
auto constMax = constant.getSplatValue<llvm::APFloat>();
minFp = constMax.convertToFloat();
minInt = constMax.convertToFloat();
minAttr = rewriter.getFloatAttr(floatTy, constMax);
maxAttr = rewriter.getFloatAttr(floatTy,
APFloat::getInf(constMax.getSemantics()));
} else if (auto intTy = cast<IntegerType>(elementTy);
intTy.isUnsignedInteger()) {
auto constMax = constant.getSplatValue<llvm::APInt>();
minAttr = rewriter.getIntegerAttr(intTy, constMax);
maxAttr =
rewriter.getIntegerAttr(intTy, APInt::getMaxValue(intTy.getWidth()));
} else {
auto constMax = constant.getSplatValue<llvm::APInt>();
minFp = constMax.getSExtValue();
minInt = constMax.getSExtValue();
minAttr = rewriter.getIntegerAttr(intTy, constMax);
maxAttr = rewriter.getIntegerAttr(
intTy, APInt::getSignedMaxValue(intTy.getWidth()));
}

rewriter.replaceOpWithNewOp<tosa::ClampOp>(
op, op.getType(), input, rewriter.getI64IntegerAttr(minInt),
rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
rewriter.getF32FloatAttr(maxFp));
rewriter.replaceOpWithNewOp<tosa::ClampOp>(op, op.getType(), input, minAttr,
maxAttr);

return success();
}
Expand Down
Loading