diff --git a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h index d5a3e5a091055..8ee8e05c32495 100644 --- a/mlir/include/mlir/Dialect/PDL/IR/Builtins.h +++ b/mlir/include/mlir/Dialect/PDL/IR/Builtins.h @@ -28,15 +28,16 @@ void registerBuiltins(PDLPatternModule &pdlPattern); namespace builtin { enum class BinaryOpKind { add, - sub, - mul, div, mod, + mul, + sub, }; enum class UnaryOpKind { - log2, + abs, exp2, + log2, }; LogicalResult createDictionaryAttr(PatternRewriter &rewriter, @@ -48,9 +49,6 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter, Attribute createArrayAttr(PatternRewriter &rewriter); Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr, Attribute element); -template -LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results, - llvm::ArrayRef args); LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results, llvm::ArrayRef args); LogicalResult div(PatternRewriter &rewriter, PDLResultList &results, @@ -65,10 +63,8 @@ LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results, llvm::ArrayRef args); LogicalResult exp2(PatternRewriter &rewriter, PDLResultList &results, llvm::ArrayRef args); - -template -LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results, - llvm::ArrayRef args); +LogicalResult abs(PatternRewriter &rewriter, PDLResultList &results, + llvm::ArrayRef args); } // namespace builtin } // namespace pdl } // namespace mlir diff --git a/mlir/lib/Dialect/PDL/IR/Builtins.cpp b/mlir/lib/Dialect/PDL/IR/Builtins.cpp index 64f8593078b28..da8959c0033d1 100644 --- a/mlir/lib/Dialect/PDL/IR/Builtins.cpp +++ b/mlir/lib/Dialect/PDL/IR/Builtins.cpp @@ -59,8 +59,8 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter, } template -LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results, - ArrayRef args) { +LogicalResult static unaryOp(PatternRewriter &rewriter, PDLResultList &results, + ArrayRef args) { assert(args.size() == 1 && "Expected one operand for unary operation"); auto operandAttr = args[0].cast(); @@ -99,6 +99,27 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results, getIntegerAsAttr(APSInt(operandIntAttr.getValue(), false))); else results.push_back(getIntegerAsAttr(operandIntAttr.getAPSInt())); + } else if constexpr (T == UnaryOpKind::abs) { + if (integerType.isSigned()) { + // check overflow + if (operandIntAttr.getAPSInt() == + APSInt::getMinValue(integerType.getIntOrFloatBitWidth(), false)) + return failure(); + + results.push_back(rewriter.getIntegerAttr( + integerType, operandIntAttr.getValue().abs())); + return success(); + } + if (integerType.isSignless()) { + // Overflow should not be checked. + // Otherwise the purpose of signless integer is meaningless. + results.push_back(rewriter.getIntegerAttr( + integerType, operandIntAttr.getValue().abs())); + return success(); + } + // If unsigned, do nothing + results.push_back(operandIntAttr); + return success(); } else { llvm::llvm_unreachable_internal( "encountered an unsupported unary operator"); @@ -140,11 +161,17 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results, }) .Default([](Type /*type*/) { return failure(); }); } else if constexpr (T == UnaryOpKind::log2) { - auto minF32 = APFloat::getSmallest(llvm::APFloat::IEEEsingle()); - - APFloat resultFloat((float)operandFloatAttr.getValue().getExactLog2()); + results.push_back(rewriter.getFloatAttr( + operandFloatAttr.getType(), + (double)operandFloatAttr.getValue().getExactLog2())); + } else if constexpr (T == UnaryOpKind::abs) { + auto resultVal = operandFloatAttr.getValue(); + resultVal.clearSign(); results.push_back( - rewriter.getFloatAttr(operandFloatAttr.getType(), resultFloat)); + rewriter.getFloatAttr(operandFloatAttr.getType(), resultVal)); + } else { + llvm::llvm_unreachable_internal( + "encountered an unsupported unary operator"); } return success(); } @@ -152,8 +179,8 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results, } template -LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results, - llvm::ArrayRef args) { +LogicalResult static binaryOp(PatternRewriter &rewriter, PDLResultList &results, + llvm::ArrayRef args) { assert(args.size() == 2 && "Expected two operands for binary operation"); auto lhsAttr = args[0].cast(); auto rhsAttr = args[1].cast(); @@ -294,6 +321,10 @@ LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results, llvm::ArrayRef args) { return unaryOp(rewriter, results, args); } +LogicalResult abs(PatternRewriter &rewriter, PDLResultList &results, + llvm::ArrayRef args) { + return unaryOp(rewriter, results, args); +} } // namespace builtin void registerBuiltins(PDLPatternModule &pdlPattern) { @@ -319,7 +350,7 @@ void registerBuiltins(PDLPatternModule &pdlPattern) { pdlPattern.registerRewriteFunction("__builtin_subRewrite", sub); pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2); pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2); - + pdlPattern.registerRewriteFunction("__builtin_absRewrite", abs); pdlPattern.registerConstraintFunctionWithResults("__builtin_mulConstraint", mul); pdlPattern.registerConstraintFunctionWithResults("__builtin_divConstraint", @@ -334,5 +365,7 @@ void registerBuiltins(PDLPatternModule &pdlPattern) { log2); pdlPattern.registerConstraintFunctionWithResults("__builtin_exp2Constraint", exp2); + pdlPattern.registerConstraintFunctionWithResults("__builtin_absConstraint", + abs); } } // namespace mlir::pdl \ No newline at end of file diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp index eff3627c0b59f..c6aa4622772c3 100644 --- a/mlir/lib/Tools/PDLL/Parser/Lexer.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.cpp @@ -377,6 +377,7 @@ Token Lexer::lexIdentifier(const char *tokStart) { .Case("_", Token::underscore) .Case("log2", Token::log2) .Case("exp2", Token::exp2) + .Case("abs", Token::abs) .Default(Token::identifier); return Token(kind, str); } diff --git a/mlir/lib/Tools/PDLL/Parser/Lexer.h b/mlir/lib/Tools/PDLL/Parser/Lexer.h index a6ce1bd8ac1fb..6699e884f49c9 100644 --- a/mlir/lib/Tools/PDLL/Parser/Lexer.h +++ b/mlir/lib/Tools/PDLL/Parser/Lexer.h @@ -80,14 +80,18 @@ class Token { equal, equal_arrow, semicolon, - /// Paired punctuation. - mul, + + /// Arithmetic. + abs, + add, div, + exp2, + log2, mod, - add, + mul, sub, - log2, - exp2, + + /// Paired punctuation. less, greater, l_brace, diff --git a/mlir/lib/Tools/PDLL/Parser/Parser.cpp b/mlir/lib/Tools/PDLL/Parser/Parser.cpp index f30d8d8492213..fcbb9f1029c62 100644 --- a/mlir/lib/Tools/PDLL/Parser/Parser.cpp +++ b/mlir/lib/Tools/PDLL/Parser/Parser.cpp @@ -334,7 +334,7 @@ class Parser { FailureOr parseLogicalAndExpr(); FailureOr parseEqualityExpr(); FailureOr parseRelationExpr(); - FailureOr parseExp2Log2Expr(); + FailureOr parseExp2Log2AbsExpr(); FailureOr parseAddSubExpr(); FailureOr parseMulDivModExpr(); FailureOr parseLogicalNotExpr(); @@ -624,6 +624,7 @@ class Parser { ast::UserRewriteDecl *subRewrite; ast::UserRewriteDecl *log2Rewrite; ast::UserRewriteDecl *exp2Rewrite; + ast::UserRewriteDecl *absRewrite; ast::UserConstraintDecl *mulConstraint; ast::UserConstraintDecl *divConstraint; ast::UserConstraintDecl *modConstraint; @@ -631,6 +632,7 @@ class Parser { ast::UserConstraintDecl *subConstraint; ast::UserConstraintDecl *log2Constraint; ast::UserConstraintDecl *exp2Constraint; + ast::UserConstraintDecl *absConstraint; } builtins{}; }; } // namespace @@ -701,6 +703,8 @@ void Parser::declareBuiltins() { "__builtin_log2Rewrite", {"Attr"}, true); builtins.exp2Rewrite = declareBuiltin( "__builtin_exp2Rewrite", {"Attr"}, true); + builtins.absRewrite = declareBuiltin( + "__builtin_absRewrite", {"Attr"}, true); builtins.mulConstraint = declareBuiltin( "__builtin_mulConstraint", {"lhs", "rhs"}, true); builtins.divConstraint = declareBuiltin( @@ -715,6 +719,8 @@ void Parser::declareBuiltins() { "__builtin_log2Constraint", {"Attr"}, true); builtins.exp2Constraint = declareBuiltin( "__builtin_exp2Constraint", {"Attr"}, true); + builtins.absConstraint = declareBuiltin( + "__builtin_absConstraint", {"Attr"}, true); } FailureOr Parser::parseModule() { @@ -2030,7 +2036,7 @@ FailureOr Parser::parseAddSubExpr() { } FailureOr Parser::parseMulDivModExpr() { - auto lhs = parseExp2Log2Expr(); + auto lhs = parseExp2Log2AbsExpr(); if (failed(lhs)) return failure(); @@ -2038,7 +2044,7 @@ FailureOr Parser::parseMulDivModExpr() { switch (curToken.getKind()) { case Token::mul: { consumeToken(); - auto rhs = parseExp2Log2Expr(); + auto rhs = parseExp2Log2AbsExpr(); if (failed(rhs)) return failure(); SmallVector args{*lhs, *rhs}; @@ -2058,7 +2064,7 @@ FailureOr Parser::parseMulDivModExpr() { } case Token::div: { consumeToken(); - auto rhs = parseExp2Log2Expr(); + auto rhs = parseExp2Log2AbsExpr(); if (failed(rhs)) return failure(); SmallVector args{*lhs, *rhs}; @@ -2078,7 +2084,7 @@ FailureOr Parser::parseMulDivModExpr() { } case Token::mod: { consumeToken(); - auto rhs = parseExp2Log2Expr(); + auto rhs = parseExp2Log2AbsExpr(); if (failed(rhs)) return failure(); SmallVector args{*lhs, *rhs}; @@ -2100,7 +2106,7 @@ FailureOr Parser::parseMulDivModExpr() { } } -FailureOr Parser::parseExp2Log2Expr() { +FailureOr Parser::parseExp2Log2AbsExpr() { FailureOr expr = nullptr; switch (curToken.getKind()) { @@ -2144,6 +2150,26 @@ FailureOr Parser::parseExp2Log2Expr() { : createBuiltinCall(curToken.getLoc(), builtins.exp2Constraint, {*expr}); } + case Token::abs: { + consumeToken(); + consumeToken(Token::l_paren); + expr = parseAddSubExpr(); + if (failed(expr)) + return failure(); + + // Check if it is in rewrite section but not in the let statement + bool inRewriteSection = parserContext == ParserContext::Rewrite; + if (inRewriteSection && nativeOperatorContext != NativeOperatorContext::Let) + return emitError("cannot evaluate abs operator in rewrite section. " + "Assign to a variable with `let`"); + + consumeToken(Token::r_paren); + return inRewriteSection + ? createBuiltinCall(curToken.getLoc(), builtins.absRewrite, + {*expr}) + : createBuiltinCall(curToken.getLoc(), builtins.absConstraint, + {*expr}); + } default: return parseLogicalNotExpr(); } diff --git a/mlir/test/mlir-pdll-lsp-server/completion.test b/mlir/test/mlir-pdll-lsp-server/completion.test index b336bc2081059..74a69928165e3 100644 --- a/mlir/test/mlir-pdll-lsp-server/completion.test +++ b/mlir/test/mlir-pdll-lsp-server/completion.test @@ -208,6 +208,12 @@ // CHECK-NEXT: "kind": 8, // CHECK-NEXT: "label": "__builtin_log2Constraint", // CHECK-NEXT: "sortText": "2___builtin_log2Constraint" +// CHECK-NEXT: }, +// CHECK-NEXT: { +// CHECK-NEXT: "detail": "(Attr: Attr) -> Attr", +// CHECK-NEXT: "kind": 8, +// CHECK-NEXT: "label": "__builtin_absConstraint", +// CHECK-NEXT: "sortText": "2___builtin_absConstraint" // CHECK-NEXT: } // CHECK-NEXT: ] // CHECK-NEXT: } diff --git a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll index 6bcf7b3789f03..b801c6177b63c 100644 --- a/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll +++ b/mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll @@ -309,6 +309,7 @@ Pattern TestAdd { // CHECK: apply_native_constraint "__builtin_modConstraint"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute // CHECK: apply_native_constraint "__builtin_log2Constraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute // CHECK: apply_native_constraint "__builtin_exp2Constraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute +// CHECK: apply_native_constraint "__builtin_absConstraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute Pattern TestOperatorsNotInRewriteSection { let a : Attr = attr<"4 : i32">; @@ -320,6 +321,7 @@ Pattern TestOperatorsNotInRewriteSection { let modConstraint : Attr = a % b; let log2Constraint : Attr = log2(a); let exp2Constraint : Attr = exp2(a); + let absConstraint : Attr = abs(a); replace op with op; } @@ -336,6 +338,7 @@ Pattern TestOperatorsNotInRewriteSection { // CHECK: apply_native_rewrite "__builtin_modRewrite"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute // CHECK: apply_native_rewrite "__builtin_log2Rewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute // CHECK: apply_native_rewrite "__builtin_exp2Rewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute +// CHECK: apply_native_rewrite "__builtin_absRewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute Pattern TestOperatorsInRewriteSection { let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange); rewrite root with { @@ -348,6 +351,7 @@ Pattern TestOperatorsInRewriteSection { let modRewrite : Attr = a % b; let log2Rewrite : Attr = log2(a); let exp2Rewrite : Attr = exp2(a); + let absRewrite : Attr = abs(a); erase root; }; } diff --git a/mlir/test/mlir-pdll/Parser/expr-failure.pdll b/mlir/test/mlir-pdll/Parser/expr-failure.pdll index 6af66de88d810..e50bb3ff322fc 100644 --- a/mlir/test/mlir-pdll/Parser/expr-failure.pdll +++ b/mlir/test/mlir-pdll/Parser/expr-failure.pdll @@ -581,6 +581,17 @@ Pattern { // ----- +Pattern { + // CHECK: cannot evaluate abs operator in rewrite section. Assign to a variable with `let` + let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange); + rewrite root with { + abs(attr<"-4 : si32">); + erase root; + }; +} + +// ----- + // check llvm::saveAndRestore works Pattern { // CHECK: cannot evaluate exp2 operator in rewrite section. Assign to a variable with `let` @@ -590,4 +601,30 @@ Pattern { exp2(attr<"4 : i32">); erase root; }; +} + +// ----- + +// check llvm::saveAndRestore works +Pattern { + // CHECK: cannot evaluate log2 operator in rewrite section. Assign to a variable with `let` + let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange); + rewrite root with { + let a : Attr = attr<"4 : i32"> + attr<"5 : i32">; + log2(attr<"4 : i32">); + erase root; + }; +} + +// ----- + +// check llvm::saveAndRestore works +Pattern { + // CHECK: cannot evaluate abs operator in rewrite section. Assign to a variable with `let` + let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange); + rewrite root with { + let a : Attr = attr<"4 : i32"> + attr<"5 : i32">; + abs(attr<"4 : i32">); + erase root; + }; } \ No newline at end of file diff --git a/mlir/test/mlir-pdll/Parser/expr.pdll b/mlir/test/mlir-pdll/Parser/expr.pdll index 71868c8a0b3af..62c597fb6ac50 100644 --- a/mlir/test/mlir-pdll/Parser/expr.pdll +++ b/mlir/test/mlir-pdll/Parser/expr.pdll @@ -450,6 +450,7 @@ Pattern { // CHECK: UserConstraintDecl {{.*}} Name<__builtin_modConstraint> ResultType // CHECK: UserConstraintDecl {{.*}} Name<__builtin_log2Constraint> ResultType // CHECK: UserConstraintDecl {{.*}} Name<__builtin_exp2Constraint> ResultType +// CHECK: UserConstraintDecl {{.*}} Name<__builtin_absConstraint> ResultType // CHECK: UserRewriteDecl {{.*}} Name<__builtin_addRewrite> ResultType // CHECK: UserRewriteDecl {{.*}} Name<__builtin_subRewrite> ResultType // CHECK: UserRewriteDecl {{.*}} Name<__builtin_mulRewrite> ResultType @@ -457,15 +458,19 @@ Pattern { // CHECK: UserRewriteDecl {{.*}} Name<__builtin_modRewrite> ResultType // CHECK: UserRewriteDecl {{.*}} Name<__builtin_log2Rewrite> ResultType // CHECK: UserRewriteDecl {{.*}} Name<__builtin_exp2Rewrite> ResultType +// CHECK: UserRewriteDecl {{.*}} Name<__builtin_absRewrite> ResultType Pattern { let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange); - let addConstraint : Attr = attr<"4 : i32"> + attr<"5 : i32">; - let subConstraint : Attr = attr<"4 : i32"> - attr<"5 : i32">; - let mulConstraint : Attr = attr<"4 : i32"> * attr<"5 : i32">; - let divConstraint : Attr = attr<"4 : i32"> / attr<"5 : i32">; - let modConstraint : Attr = attr<"4 : i32"> % attr<"5 : i32">; - let log2Constraint : Attr = log2(attr<"4 : i32">); - let exp2Constraint : Attr = exp2(attr<"4 : i32">); + let a : Attr = attr<"4 : i32">; + let b : Attr = attr<"5 : i32">; + let addConstraint : Attr = a + b; + let subConstraint : Attr = a - b; + let mulConstraint : Attr = a * b; + let divConstraint : Attr = a / b; + let modConstraint : Attr = a % b; + let log2Constraint : Attr = log2(a); + let exp2Constraint : Attr = exp2(b); + let absConstraint : Attr = abs(attr<"-4 : si32">); rewrite root with { let addRewrite : Attr = attr<"4 : i32"> + attr<"5 : i32">; let subRewrite : Attr = attr<"4 : i32"> - attr<"5 : i32">; @@ -474,6 +479,7 @@ Pattern { let modRewrite : Attr = attr<"4 : i32"> % attr<"5 : i32">; let log2Rewrite : Attr = log2(attr<"4 : i32">); let exp2Rewrite : Attr = exp2(attr<"4 : i32">); + let absRewrite : Attr = abs(attr<"-4 : si32">); erase root; }; } @@ -488,6 +494,7 @@ Pattern { // CHECK: UserConstraintDecl {{.*}} Name<__builtin_modConstraint> ResultType // CHECK: UserConstraintDecl {{.*}} Name<__builtin_log2Constraint> ResultType // CHECK: UserConstraintDecl {{.*}} Name<__builtin_exp2Constraint> ResultType +// CHECK: UserConstraintDecl {{.*}} Name<__builtin_absConstraint> ResultType // CHECK: UserRewriteDecl {{.*}} Name<__builtin_addRewrite> ResultType // CHECK: UserRewriteDecl {{.*}} Name<__builtin_subRewrite> ResultType // CHECK: UserRewriteDecl {{.*}} Name<__builtin_mulRewrite> ResultType @@ -495,6 +502,7 @@ Pattern { // CHECK: UserRewriteDecl {{.*}} Name<__builtin_modRewrite> ResultType // CHECK: UserRewriteDecl {{.*}} Name<__builtin_log2Rewrite> ResultType // CHECK: UserRewriteDecl {{.*}} Name<__builtin_exp2Rewrite> ResultType +// CHECK: UserRewriteDecl {{.*}} Name<__builtin_absRewrite> ResultType Constraint TestConstraint(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) { let a : Attr = attr<"4 : i32">; let b : Attr = attr<"5 : i32">; @@ -505,6 +513,7 @@ Constraint TestConstraint(attr: Attr, op: Op, type: Type, value: Value, typeRang let modConstraint : Attr = a % b; let log2Constraint : Attr = log2(a); let exp2Constraint : Attr = exp2(b); + let absConstraint : Attr = abs(attr<"-4 : si32">); } Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) { @@ -517,6 +526,7 @@ Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: Typ let modRewrite : Attr = c % d; let log2Rewrite : Attr = log2(c); let exp2Rewrite : Attr = exp2(d); + let absRewrite : Attr = abs(attr<"-4 : si32">); } Pattern TestOperatorContext { @@ -535,10 +545,14 @@ Pattern TestOperatorContext { // CHECK: -UserConstraintDecl {{.*}} Name<__builtin_addConstraint> ResultType // CHECK: -UserConstraintDecl {{.*}} Name<__builtin_exp2Constraint> ResultType // CHECK: Arguments -// CHECK: -UserConstraintDecl {{.*}} Name<__builtin_addConstraint> ResultType +// CHECK: -UserConstraintDecl {{.*}} Name<__builtin_addConstraint> ResultType +// CHECK: -UserConstraintDecl {{.*}} Name<__builtin_absConstraint> ResultType +// CHECK: Arguments +// CHECK: -UserConstraintDecl {{.*}} Name<__builtin_addConstraint> ResultType Pattern { let log2Constraint : Attr = log2(attr<"4 : i32"> + attr<"4 : i32">); let exp2Constraint : Attr = exp2(attr<"2 : i32"> + attr<"2 : i32">); + let absConstraint : Attr = abs(attr<"-4 : si32"> + attr<"-2 : si32">); erase _: Op; } @@ -553,6 +567,7 @@ Pattern { // CHECK: UserConstraintDecl {{.*}} Name<__builtin_modConstraint> ResultType // CHECK: UserConstraintDecl {{.*}} Name<__builtin_log2Constraint> ResultType // CHECK: UserConstraintDecl {{.*}} Name<__builtin_exp2Constraint> ResultType +// CHECK: UserConstraintDecl {{.*}} Name<__builtin_absConstraint> ResultType Pattern { attr<"4 : i32"> + attr<"5 : i32">; attr<"4 : i32"> - attr<"5 : i32">; @@ -561,6 +576,7 @@ Pattern { attr<"4 : i32"> % attr<"5 : i32">; log2(attr<"4 : i32">); exp2(attr<"4 : i32">); + abs(attr<"-4 : si32">); erase _: Op; } @@ -614,4 +630,4 @@ Pattern { let c : Attr = attr<"6 : i32">; let x : Attr = a + b * c; erase _: Op; -} +} \ No newline at end of file diff --git a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp index d89312b5c6880..e31fa6f53a8a1 100644 --- a/mlir/unittests/Dialect/PDL/BuiltinTest.cpp +++ b/mlir/unittests/Dialect/PDL/BuiltinTest.cpp @@ -617,12 +617,12 @@ TEST_F(BuiltinTest, log2) { "log2 of an integer is expected to return an exact integer."); } - auto fourF32 = rewriter.getF32FloatAttr(4.0); + auto fourF16 = rewriter.getF16FloatAttr(4.0); // check correctness { TestPDLResultList results(1); - EXPECT_TRUE(builtin::log2(rewriter, results, {fourF32}).succeeded()); + EXPECT_TRUE(builtin::log2(rewriter, results, {fourF16}).succeeded()); PDLValue result = results.getResults()[0]; EXPECT_EQ( @@ -709,4 +709,71 @@ TEST_F(BuiltinTest, exp2) { builtin::exp2(rewriter, results, {minusHundredFiftyF32}).failed()); } } + +TEST_F(BuiltinTest, abs) { + // signed integer overflow + { + auto SI8Type = rewriter.getIntegerType(8, true); + auto value = rewriter.getIntegerAttr(SI8Type, -128); + TestPDLResultList results(1); + EXPECT_TRUE(builtin::abs(rewriter, results, {value}).failed()); + } + + // signed integer correctness + { + auto value = rewriter.getSI32IntegerAttr(-1); + TestPDLResultList results(1); + EXPECT_TRUE(builtin::abs(rewriter, results, {value}).succeeded()); + auto result = results.getResults()[0]; + EXPECT_EQ( + cast(result.cast()).getValue().getSExtValue(), + 1); + } + + // unsigned integer + { + auto value = rewriter.getUI32IntegerAttr(1); + TestPDLResultList results(1); + EXPECT_TRUE(builtin::abs(rewriter, results, {value}).succeeded()); + auto result = results.getResults()[0]; + EXPECT_EQ( + cast(result.cast()).getValue().getZExtValue(), + (uint64_t)1); + } + + // signless integer + { + auto value = rewriter.getI8IntegerAttr(-7); + TestPDLResultList results(1); + EXPECT_TRUE(builtin::abs(rewriter, results, {value}).succeeded()); + auto result = results.getResults()[0]; + EXPECT_EQ( + cast(result.cast()).getValue().getSExtValue(), + 7); + } + + // signless integer: edge case -128 + // Overflow should not be checked + // otherwise the purpose of signless integer is meaningless + { + auto value = rewriter.getI8IntegerAttr(-128); + TestPDLResultList results(1); + EXPECT_TRUE(builtin::abs(rewriter, results, {value}).succeeded()); + auto result = results.getResults()[0]; + EXPECT_EQ( + cast(result.cast()).getValue().getSExtValue(), + -128); + } + + // float + { + auto value = rewriter.getF32FloatAttr(-1.0); + TestPDLResultList results(1); + EXPECT_TRUE(builtin::abs(rewriter, results, {value}).succeeded()); + auto result = results.getResults()[0]; + EXPECT_EQ( + cast(result.cast()).getValue().convertToFloat(), + 1.0); + } +} } // namespace