Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PDLL] Emit warning for inexact result of floating point binary arithme… #169

Merged
merged 2 commits into from
Apr 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
@@ -102,18 +102,12 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
} else {
llvm::llvm_unreachable_internal(
"encountered an unsupported unary operator");
return failure();
}
return success();
}

if (auto operandFloatAttr = dyn_cast_or_null<FloatAttr>(operandAttr)) {
// auto floatType = operandFloatAttr.getType();

if constexpr (T == UnaryOpKind::exp2) {
// auto maxVal = APFloat::getLargest(llvm::APFloat::IEEEhalf());
// auto minVal = APFloat::getSmallest(llvm::APFloat::IEEEhalf());

auto type = operandFloatAttr.getType();

return TypeSwitch<Type, LogicalResult>(type)
@@ -166,9 +160,8 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,

if (auto lhsIntAttr = dyn_cast_or_null<IntegerAttr>(lhsAttr)) {
auto rhsIntAttr = dyn_cast_or_null<IntegerAttr>(rhsAttr);
if (!rhsIntAttr || lhsIntAttr.getType() != rhsIntAttr.getType()) {
if (!rhsIntAttr || lhsIntAttr.getType() != rhsIntAttr.getType())
return failure();
}

auto integerType = lhsIntAttr.getType();
APInt resultAPInt;
@@ -211,7 +204,8 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
resultAPInt = lhsIntAttr.getValue().srem(rhsIntAttr.getValue());
}
} else {
assert(false && "Unsupported binary operator");
llvm::llvm_unreachable_internal(
"encounter an unsupported binary operator.");
}

if (isOverflow)
@@ -223,9 +217,8 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,

if (auto lhsFloatAttr = dyn_cast_or_null<FloatAttr>(lhsAttr)) {
auto rhsFloatAttr = dyn_cast_or_null<FloatAttr>(rhsAttr);
if (!rhsFloatAttr || lhsFloatAttr.getType() != rhsFloatAttr.getType()) {
if (!rhsFloatAttr || lhsFloatAttr.getType() != rhsFloatAttr.getType())
return failure();
}

APFloat lhsVal = lhsFloatAttr.getValue();
APFloat rhsVal = rhsFloatAttr.getValue();
@@ -248,13 +241,19 @@ LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
} else if constexpr (T == BinaryOpKind::mod) {
operationStatus = resultVal.mod(rhsVal);
} else {
assert(false && "Unsupported binary operator");
llvm::llvm_unreachable_internal(
"encounter an unsupported binary operator.");
}

if (operationStatus != APFloat::opOK) {
return failure();
}
if (operationStatus != APFloat::opInexact)
return failure();

emitWarning(rewriter.getUnknownLoc())
<< "Binary arithmetic operation between " << lhsVal.convertToFloat()
<< " and " << rhsVal.convertToFloat()
<< " produced an inexact result";
}
results.push_back(rewriter.getFloatAttr(floatType, resultVal));
return success();
}
38 changes: 18 additions & 20 deletions mlir/unittests/Dialect/PDL/BuiltinTest.cpp
Original file line number Diff line number Diff line change
@@ -253,10 +253,19 @@ TEST_F(BuiltinTest, div) {
"Divide by zero?");
}

auto smallF16 = rewriter.getF16FloatAttr(0.0001);
auto BF16Type = rewriter.getBF16Type();
auto oneBF16 = rewriter.getFloatAttr(BF16Type, 1.0);
auto nineBF16 = rewriter.getFloatAttr(BF16Type, 9.0);

// float: inexact result
// return success(), but warning is emitted.
{
TestPDLResultList results(1);
EXPECT_TRUE(
builtin::div(rewriter, results, {oneBF16, nineBF16}).succeeded());
}

auto twoF16 = rewriter.getF16FloatAttr(2.0);
auto maxValF16 = rewriter.getF16FloatAttr(
llvm::APFloat::getLargest(llvm::APFloat::IEEEhalf()).convertToFloat());
auto zeroF16 = rewriter.getF16FloatAttr(0.0);
auto negzeroF16 = rewriter.getF16FloatAttr(-0.0);

@@ -272,13 +281,6 @@ TEST_F(BuiltinTest, div) {
EXPECT_TRUE(builtin::div(rewriter, results, {twoF16, negzeroF16}).failed());
}

// float: overflow
{
TestPDLResultList results(1);
EXPECT_TRUE(
builtin::div(rewriter, results, {maxValF16, smallF16}).failed());
}

// float: correctness
{
TestPDLResultList results(1);
@@ -456,19 +458,17 @@ TEST_F(BuiltinTest, add) {
EXPECT_TRUE(builtin::add(rewriter, results, {oneI16, oneI32}).failed());
}

auto oneF16 = rewriter.getF16FloatAttr(1.0);
auto oneF32 = rewriter.getF32FloatAttr(1.0);
auto zeroF32 = rewriter.getF32FloatAttr(0.0);
auto negzeroF32 = rewriter.getF32FloatAttr(-0.0);
auto zeroF64 = rewriter.getF64FloatAttr(0.0);

auto maxValF16 = rewriter.getF16FloatAttr(
llvm::APFloat::getLargest(llvm::APFloat::IEEEhalf()).convertToFloat());
auto overflowF16 = rewriter.getF16FloatAttr(32768);

// float: overflow
{
TestPDLResultList results(1);
EXPECT_TRUE(builtin::add(rewriter, results, {oneF16, maxValF16}).failed());
EXPECT_TRUE(
builtin::add(rewriter, results, {overflowF16, overflowF16}).failed());
}

// float: correctness
@@ -553,19 +553,17 @@ TEST_F(BuiltinTest, sub) {
EXPECT_TRUE(builtin::sub(rewriter, results, {oneI16, oneI32}).failed());
}

auto oneF16 = rewriter.getF16FloatAttr(1.0);
auto oneF16 = rewriter.getF16FloatAttr(100.0);
auto oneF32 = rewriter.getF32FloatAttr(1.0);
auto zeroF32 = rewriter.getF32FloatAttr(0.0);
auto negzeroF32 = rewriter.getF32FloatAttr(-0.0);
auto zeroF64 = rewriter.getF64FloatAttr(0.0);

auto maxValF16 = rewriter.getF16FloatAttr(
llvm::APFloat::getLargest(llvm::APFloat::IEEEhalf()).convertToFloat());
auto minValF16 = rewriter.getF16FloatAttr(-65504);

// float: overflow
{
TestPDLResultList results(1);
EXPECT_TRUE(builtin::sub(rewriter, results, {maxValF16, oneF16}).failed());
EXPECT_TRUE(builtin::sub(rewriter, results, {oneF16, minValF16}).failed());
}

// float: correctness