Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PDLL native operator abs #173

Merged
merged 6 commits into from
May 3, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 6 additions & 10 deletions mlir/include/mlir/Dialect/PDL/IR/Builtins.h
Original file line number Diff line number Diff line change
@@ -28,15 +28,16 @@ void registerBuiltins(PDLPatternModule &pdlPattern);
namespace builtin {
enum class BinaryOpKind {
add,
sub,
mul,
div,
mod,
mul,
sub,
};

enum class UnaryOpKind {
log2,
abs,
exp2,
log2,
};

LogicalResult createDictionaryAttr(PatternRewriter &rewriter,
@@ -48,9 +49,6 @@ LogicalResult addEntryToDictionaryAttr(PatternRewriter &rewriter,
Attribute createArrayAttr(PatternRewriter &rewriter);
Attribute addElemToArrayAttr(PatternRewriter &rewriter, Attribute attr,
Attribute element);
template <BinaryOpKind T>
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
LogicalResult mul(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
LogicalResult div(PatternRewriter &rewriter, PDLResultList &results,
@@ -65,10 +63,8 @@ LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
LogicalResult exp2(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);

template <BinaryOpKind T>
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
LogicalResult abs(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args);
} // namespace builtin
} // namespace pdl
} // namespace mlir
51 changes: 42 additions & 9 deletions mlir/lib/Dialect/PDL/IR/Builtins.cpp
Original file line number Diff line number Diff line change
@@ -59,8 +59,8 @@ mlir::Attribute addElemToArrayAttr(mlir::PatternRewriter &rewriter,
}

template <UnaryOpKind T>
LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
LogicalResult static unaryOp(PatternRewriter &rewriter, PDLResultList &results,
ArrayRef<PDLValue> args) {
assert(args.size() == 1 && "Expected one operand for unary operation");
auto operandAttr = args[0].cast<Attribute>();

@@ -99,6 +99,27 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
getIntegerAsAttr(APSInt(operandIntAttr.getValue(), false)));
else
results.push_back(getIntegerAsAttr(operandIntAttr.getAPSInt()));
} else if constexpr (T == UnaryOpKind::abs) {
if (integerType.isSigned()) {
// check overflow
if (operandIntAttr.getAPSInt() ==
APSInt::getMinValue(integerType.getIntOrFloatBitWidth(), false))
return failure();

results.push_back(rewriter.getIntegerAttr(
integerType, operandIntAttr.getValue().abs()));
return success();
}
if (integerType.isSignless()) {
// Overflow should not be checked.
// Otherwise the purpose of signless integer is meaningless.
results.push_back(rewriter.getIntegerAttr(
integerType, operandIntAttr.getValue().abs()));
return success();
}
// If unsigned, do nothing
results.push_back(operandIntAttr);
return success();
} else {
llvm::llvm_unreachable_internal(
"encountered an unsupported unary operator");
@@ -140,20 +161,26 @@ LogicalResult unaryOp(PatternRewriter &rewriter, PDLResultList &results,
})
.Default([](Type /*type*/) { return failure(); });
} else if constexpr (T == UnaryOpKind::log2) {
auto minF32 = APFloat::getSmallest(llvm::APFloat::IEEEsingle());

APFloat resultFloat((float)operandFloatAttr.getValue().getExactLog2());
results.push_back(rewriter.getFloatAttr(
operandFloatAttr.getType(),
(double)operandFloatAttr.getValue().getExactLog2()));
} else if constexpr (T == UnaryOpKind::abs) {
auto resultVal = operandFloatAttr.getValue();
resultVal.clearSign();
results.push_back(
rewriter.getFloatAttr(operandFloatAttr.getType(), resultFloat));
rewriter.getFloatAttr(operandFloatAttr.getType(), resultVal));
} else {
llvm::llvm_unreachable_internal(
"encountered an unsupported unary operator");
}
return success();
}
return failure();
}

template <BinaryOpKind T>
LogicalResult binaryOp(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args) {
LogicalResult static binaryOp(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args) {
assert(args.size() == 2 && "Expected two operands for binary operation");
auto lhsAttr = args[0].cast<Attribute>();
auto rhsAttr = args[1].cast<Attribute>();
@@ -294,6 +321,10 @@ LogicalResult log2(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args) {
return unaryOp<UnaryOpKind::log2>(rewriter, results, args);
}
LogicalResult abs(PatternRewriter &rewriter, PDLResultList &results,
llvm::ArrayRef<PDLValue> args) {
return unaryOp<UnaryOpKind::abs>(rewriter, results, args);
}
} // namespace builtin

void registerBuiltins(PDLPatternModule &pdlPattern) {
@@ -319,7 +350,7 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
pdlPattern.registerRewriteFunction("__builtin_subRewrite", sub);
pdlPattern.registerRewriteFunction("__builtin_log2Rewrite", log2);
pdlPattern.registerRewriteFunction("__builtin_exp2Rewrite", exp2);

pdlPattern.registerRewriteFunction("__builtin_absRewrite", abs);
pdlPattern.registerConstraintFunctionWithResults("__builtin_mulConstraint",
mul);
pdlPattern.registerConstraintFunctionWithResults("__builtin_divConstraint",
@@ -334,5 +365,7 @@ void registerBuiltins(PDLPatternModule &pdlPattern) {
log2);
pdlPattern.registerConstraintFunctionWithResults("__builtin_exp2Constraint",
exp2);
pdlPattern.registerConstraintFunctionWithResults("__builtin_absConstraint",
abs);
}
} // namespace mlir::pdl
1 change: 1 addition & 0 deletions mlir/lib/Tools/PDLL/Parser/Lexer.cpp
Original file line number Diff line number Diff line change
@@ -377,6 +377,7 @@ Token Lexer::lexIdentifier(const char *tokStart) {
.Case("_", Token::underscore)
.Case("log2", Token::log2)
.Case("exp2", Token::exp2)
.Case("abs", Token::abs)
.Default(Token::identifier);
return Token(kind, str);
}
14 changes: 9 additions & 5 deletions mlir/lib/Tools/PDLL/Parser/Lexer.h
Original file line number Diff line number Diff line change
@@ -80,14 +80,18 @@ class Token {
equal,
equal_arrow,
semicolon,
/// Paired punctuation.
mul,

/// Arithmetic.
abs,
add,
div,
exp2,
log2,
mod,
add,
mul,
sub,
log2,
exp2,

/// Paired punctuation.
less,
greater,
l_brace,
38 changes: 32 additions & 6 deletions mlir/lib/Tools/PDLL/Parser/Parser.cpp
Original file line number Diff line number Diff line change
@@ -334,7 +334,7 @@ class Parser {
FailureOr<ast::Expr *> parseLogicalAndExpr();
FailureOr<ast::Expr *> parseEqualityExpr();
FailureOr<ast::Expr *> parseRelationExpr();
FailureOr<ast::Expr *> parseExp2Log2Expr();
FailureOr<ast::Expr *> parseExp2Log2AbsExpr();
FailureOr<ast::Expr *> parseAddSubExpr();
FailureOr<ast::Expr *> parseMulDivModExpr();
FailureOr<ast::Expr *> parseLogicalNotExpr();
@@ -624,13 +624,15 @@ class Parser {
ast::UserRewriteDecl *subRewrite;
ast::UserRewriteDecl *log2Rewrite;
ast::UserRewriteDecl *exp2Rewrite;
ast::UserRewriteDecl *absRewrite;
ast::UserConstraintDecl *mulConstraint;
ast::UserConstraintDecl *divConstraint;
ast::UserConstraintDecl *modConstraint;
ast::UserConstraintDecl *addConstraint;
ast::UserConstraintDecl *subConstraint;
ast::UserConstraintDecl *log2Constraint;
ast::UserConstraintDecl *exp2Constraint;
ast::UserConstraintDecl *absConstraint;
} builtins{};
};
} // namespace
@@ -701,6 +703,8 @@ void Parser::declareBuiltins() {
"__builtin_log2Rewrite", {"Attr"}, true);
builtins.exp2Rewrite = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_exp2Rewrite", {"Attr"}, true);
builtins.absRewrite = declareBuiltin<ast::UserRewriteDecl>(
"__builtin_absRewrite", {"Attr"}, true);
builtins.mulConstraint = declareBuiltin<ast::UserConstraintDecl>(
"__builtin_mulConstraint", {"lhs", "rhs"}, true);
builtins.divConstraint = declareBuiltin<ast::UserConstraintDecl>(
@@ -715,6 +719,8 @@ void Parser::declareBuiltins() {
"__builtin_log2Constraint", {"Attr"}, true);
builtins.exp2Constraint = declareBuiltin<ast::UserConstraintDecl>(
"__builtin_exp2Constraint", {"Attr"}, true);
builtins.absConstraint = declareBuiltin<ast::UserConstraintDecl>(
"__builtin_absConstraint", {"Attr"}, true);
}

FailureOr<ast::Module *> Parser::parseModule() {
@@ -2030,15 +2036,15 @@ FailureOr<ast::Expr *> Parser::parseAddSubExpr() {
}

FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
auto lhs = parseExp2Log2Expr();
auto lhs = parseExp2Log2AbsExpr();
if (failed(lhs))
return failure();

for (;;) {
switch (curToken.getKind()) {
case Token::mul: {
consumeToken();
auto rhs = parseExp2Log2Expr();
auto rhs = parseExp2Log2AbsExpr();
if (failed(rhs))
return failure();
SmallVector<ast::Expr *> args{*lhs, *rhs};
@@ -2058,7 +2064,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
}
case Token::div: {
consumeToken();
auto rhs = parseExp2Log2Expr();
auto rhs = parseExp2Log2AbsExpr();
if (failed(rhs))
return failure();
SmallVector<ast::Expr *> args{*lhs, *rhs};
@@ -2078,7 +2084,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
}
case Token::mod: {
consumeToken();
auto rhs = parseExp2Log2Expr();
auto rhs = parseExp2Log2AbsExpr();
if (failed(rhs))
return failure();
SmallVector<ast::Expr *> args{*lhs, *rhs};
@@ -2100,7 +2106,7 @@ FailureOr<ast::Expr *> Parser::parseMulDivModExpr() {
}
}

FailureOr<ast::Expr *> Parser::parseExp2Log2Expr() {
FailureOr<ast::Expr *> Parser::parseExp2Log2AbsExpr() {
FailureOr<ast::Expr *> expr = nullptr;

switch (curToken.getKind()) {
@@ -2144,6 +2150,26 @@ FailureOr<ast::Expr *> Parser::parseExp2Log2Expr() {
: createBuiltinCall(curToken.getLoc(), builtins.exp2Constraint,
{*expr});
}
case Token::abs: {
consumeToken();
consumeToken(Token::l_paren);
expr = parseAddSubExpr();
if (failed(expr))
return failure();

// Check if it is in rewrite section but not in the let statement
bool inRewriteSection = parserContext == ParserContext::Rewrite;
if (inRewriteSection && nativeOperatorContext != NativeOperatorContext::Let)
return emitError("cannot evaluate abs operator in rewrite section. "
"Assign to a variable with `let`");

consumeToken(Token::r_paren);
return inRewriteSection
? createBuiltinCall(curToken.getLoc(), builtins.absRewrite,
{*expr})
: createBuiltinCall(curToken.getLoc(), builtins.absConstraint,
{*expr});
}
default:
return parseLogicalNotExpr();
}
6 changes: 6 additions & 0 deletions mlir/test/mlir-pdll-lsp-server/completion.test
Original file line number Diff line number Diff line change
@@ -208,6 +208,12 @@
// CHECK-NEXT: "kind": 8,
// CHECK-NEXT: "label": "__builtin_log2Constraint",
// CHECK-NEXT: "sortText": "2___builtin_log2Constraint"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "(Attr: Attr) -> Attr",
// CHECK-NEXT: "kind": 8,
// CHECK-NEXT: "label": "__builtin_absConstraint",
// CHECK-NEXT: "sortText": "2___builtin_absConstraint"
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: }
4 changes: 4 additions & 0 deletions mlir/test/mlir-pdll/CodeGen/MLIR/expr.pdll
Original file line number Diff line number Diff line change
@@ -309,6 +309,7 @@ Pattern TestAdd {
// CHECK: apply_native_constraint "__builtin_modConstraint"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_constraint "__builtin_log2Constraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_constraint "__builtin_exp2Constraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_constraint "__builtin_absConstraint"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute

Pattern TestOperatorsNotInRewriteSection {
let a : Attr = attr<"4 : i32">;
@@ -320,6 +321,7 @@ Pattern TestOperatorsNotInRewriteSection {
let modConstraint : Attr = a % b;
let log2Constraint : Attr = log2(a);
let exp2Constraint : Attr = exp2(a);
let absConstraint : Attr = abs(a);
replace op<test.simple> with op<test.success>;
}

@@ -336,6 +338,7 @@ Pattern TestOperatorsNotInRewriteSection {
// CHECK: apply_native_rewrite "__builtin_modRewrite"(%[[VAL_0]], %[[VAL_1]] : !pdl.attribute, !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_rewrite "__builtin_log2Rewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_rewrite "__builtin_exp2Rewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
// CHECK: apply_native_rewrite "__builtin_absRewrite"(%[[VAL_0]] : !pdl.attribute) : !pdl.attribute
Pattern TestOperatorsInRewriteSection {
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
rewrite root with {
@@ -348,6 +351,7 @@ Pattern TestOperatorsInRewriteSection {
let modRewrite : Attr = a % b;
let log2Rewrite : Attr = log2(a);
let exp2Rewrite : Attr = exp2(a);
let absRewrite : Attr = abs(a);
erase root;
};
}
37 changes: 37 additions & 0 deletions mlir/test/mlir-pdll/Parser/expr-failure.pdll
Original file line number Diff line number Diff line change
@@ -581,6 +581,17 @@ Pattern {

// -----

Pattern {
// CHECK: cannot evaluate abs operator in rewrite section. Assign to a variable with `let`
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
rewrite root with {
abs(attr<"-4 : si32">);
erase root;
};
}

// -----

// check llvm::saveAndRestore works
Pattern {
// CHECK: cannot evaluate exp2 operator in rewrite section. Assign to a variable with `let`
@@ -590,4 +601,30 @@ Pattern {
exp2(attr<"4 : i32">);
erase root;
};
}

// -----

// check llvm::saveAndRestore works
Pattern {
// CHECK: cannot evaluate log2 operator in rewrite section. Assign to a variable with `let`
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
rewrite root with {
let a : Attr = attr<"4 : i32"> + attr<"5 : i32">;
log2(attr<"4 : i32">);
erase root;
};
}

// -----

// check llvm::saveAndRestore works
Pattern {
// CHECK: cannot evaluate abs operator in rewrite section. Assign to a variable with `let`
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
rewrite root with {
let a : Attr = attr<"4 : i32"> + attr<"5 : i32">;
abs(attr<"4 : i32">);
erase root;
};
}
34 changes: 25 additions & 9 deletions mlir/test/mlir-pdll/Parser/expr.pdll
Original file line number Diff line number Diff line change
@@ -450,22 +450,27 @@ Pattern {
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_modConstraint> ResultType<Attr>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_log2Constraint> ResultType<Attr>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_exp2Constraint> ResultType<Attr>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_absConstraint> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_addRewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_subRewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_mulRewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_divRewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_modRewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_log2Rewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_exp2Rewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_absRewrite> ResultType<Attr>
Pattern {
let root = op<>(operand: Value, operands: ValueRange) -> (type: Type, types: TypeRange);
let addConstraint : Attr = attr<"4 : i32"> + attr<"5 : i32">;
let subConstraint : Attr = attr<"4 : i32"> - attr<"5 : i32">;
let mulConstraint : Attr = attr<"4 : i32"> * attr<"5 : i32">;
let divConstraint : Attr = attr<"4 : i32"> / attr<"5 : i32">;
let modConstraint : Attr = attr<"4 : i32"> % attr<"5 : i32">;
let log2Constraint : Attr = log2(attr<"4 : i32">);
let exp2Constraint : Attr = exp2(attr<"4 : i32">);
let a : Attr = attr<"4 : i32">;
let b : Attr = attr<"5 : i32">;
let addConstraint : Attr = a + b;
let subConstraint : Attr = a - b;
let mulConstraint : Attr = a * b;
let divConstraint : Attr = a / b;
let modConstraint : Attr = a % b;
let log2Constraint : Attr = log2(a);
let exp2Constraint : Attr = exp2(b);
let absConstraint : Attr = abs(attr<"-4 : si32">);
rewrite root with {
let addRewrite : Attr = attr<"4 : i32"> + attr<"5 : i32">;
let subRewrite : Attr = attr<"4 : i32"> - attr<"5 : i32">;
@@ -474,6 +479,7 @@ Pattern {
let modRewrite : Attr = attr<"4 : i32"> % attr<"5 : i32">;
let log2Rewrite : Attr = log2(attr<"4 : i32">);
let exp2Rewrite : Attr = exp2(attr<"4 : i32">);
let absRewrite : Attr = abs(attr<"-4 : si32">);
erase root;
};
}
@@ -488,13 +494,15 @@ Pattern {
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_modConstraint> ResultType<Attr>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_log2Constraint> ResultType<Attr>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_exp2Constraint> ResultType<Attr>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_absConstraint> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_addRewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_subRewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_mulRewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_divRewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_modRewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_log2Rewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_exp2Rewrite> ResultType<Attr>
// CHECK: UserRewriteDecl {{.*}} Name<__builtin_absRewrite> ResultType<Attr>
Constraint TestConstraint(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) {
let a : Attr = attr<"4 : i32">;
let b : Attr = attr<"5 : i32">;
@@ -505,6 +513,7 @@ Constraint TestConstraint(attr: Attr, op: Op, type: Type, value: Value, typeRang
let modConstraint : Attr = a % b;
let log2Constraint : Attr = log2(a);
let exp2Constraint : Attr = exp2(b);
let absConstraint : Attr = abs(attr<"-4 : si32">);
}

Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: TypeRange, valueRange: ValueRange) {
@@ -517,6 +526,7 @@ Rewrite TestRewrite(attr: Attr, op: Op, type: Type, value: Value, typeRange: Typ
let modRewrite : Attr = c % d;
let log2Rewrite : Attr = log2(c);
let exp2Rewrite : Attr = exp2(d);
let absRewrite : Attr = abs(attr<"-4 : si32">);
}

Pattern TestOperatorContext {
@@ -535,10 +545,14 @@ Pattern TestOperatorContext {
// CHECK: -UserConstraintDecl {{.*}} Name<__builtin_addConstraint> ResultType<Attr>
// CHECK: -UserConstraintDecl {{.*}} Name<__builtin_exp2Constraint> ResultType<Attr>
// CHECK: Arguments
// CHECK: -UserConstraintDecl {{.*}} Name<__builtin_addConstraint> ResultType<Attr>
// CHECK: -UserConstraintDecl {{.*}} Name<__builtin_addConstraint> ResultType<Attr>
// CHECK: -UserConstraintDecl {{.*}} Name<__builtin_absConstraint> ResultType<Attr>
// CHECK: Arguments
// CHECK: -UserConstraintDecl {{.*}} Name<__builtin_addConstraint> ResultType<Attr>
Pattern {
let log2Constraint : Attr = log2(attr<"4 : i32"> + attr<"4 : i32">);
let exp2Constraint : Attr = exp2(attr<"2 : i32"> + attr<"2 : i32">);
let absConstraint : Attr = abs(attr<"-4 : si32"> + attr<"-2 : si32">);
erase _: Op;
}

@@ -553,6 +567,7 @@ Pattern {
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_modConstraint> ResultType<Attr>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_log2Constraint> ResultType<Attr>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_exp2Constraint> ResultType<Attr>
// CHECK: UserConstraintDecl {{.*}} Name<__builtin_absConstraint> ResultType<Attr>
Pattern {
attr<"4 : i32"> + attr<"5 : i32">;
attr<"4 : i32"> - attr<"5 : i32">;
@@ -561,6 +576,7 @@ Pattern {
attr<"4 : i32"> % attr<"5 : i32">;
log2(attr<"4 : i32">);
exp2(attr<"4 : i32">);
abs(attr<"-4 : si32">);
erase _: Op;
}

@@ -614,4 +630,4 @@ Pattern {
let c : Attr = attr<"6 : i32">;
let x : Attr = a + b * c;
erase _: Op;
}
}
71 changes: 69 additions & 2 deletions mlir/unittests/Dialect/PDL/BuiltinTest.cpp
Original file line number Diff line number Diff line change
@@ -617,12 +617,12 @@ TEST_F(BuiltinTest, log2) {
"log2 of an integer is expected to return an exact integer.");
}

auto fourF32 = rewriter.getF32FloatAttr(4.0);
auto fourF16 = rewriter.getF16FloatAttr(4.0);

// check correctness
{
TestPDLResultList results(1);
EXPECT_TRUE(builtin::log2(rewriter, results, {fourF32}).succeeded());
EXPECT_TRUE(builtin::log2(rewriter, results, {fourF16}).succeeded());

PDLValue result = results.getResults()[0];
EXPECT_EQ(
@@ -709,4 +709,71 @@ TEST_F(BuiltinTest, exp2) {
builtin::exp2(rewriter, results, {minusHundredFiftyF32}).failed());
}
}

TEST_F(BuiltinTest, abs) {
// signed integer overflow
{
auto SI8Type = rewriter.getIntegerType(8, true);
auto value = rewriter.getIntegerAttr(SI8Type, -128);
TestPDLResultList results(1);
EXPECT_TRUE(builtin::abs(rewriter, results, {value}).failed());
}

// signed integer correctness
{
auto value = rewriter.getSI32IntegerAttr(-1);
TestPDLResultList results(1);
EXPECT_TRUE(builtin::abs(rewriter, results, {value}).succeeded());
auto result = results.getResults()[0];
EXPECT_EQ(
cast<IntegerAttr>(result.cast<Attribute>()).getValue().getSExtValue(),
1);
}

// unsigned integer
{
auto value = rewriter.getUI32IntegerAttr(1);
TestPDLResultList results(1);
EXPECT_TRUE(builtin::abs(rewriter, results, {value}).succeeded());
auto result = results.getResults()[0];
EXPECT_EQ(
cast<IntegerAttr>(result.cast<Attribute>()).getValue().getZExtValue(),
(uint64_t)1);
}

// signless integer
{
auto value = rewriter.getI8IntegerAttr(-7);
TestPDLResultList results(1);
EXPECT_TRUE(builtin::abs(rewriter, results, {value}).succeeded());
auto result = results.getResults()[0];
EXPECT_EQ(
cast<IntegerAttr>(result.cast<Attribute>()).getValue().getSExtValue(),
7);
}

// signless integer: edge case -128
// Overflow should not be checked
// otherwise the purpose of signless integer is meaningless
{
auto value = rewriter.getI8IntegerAttr(-128);
TestPDLResultList results(1);
EXPECT_TRUE(builtin::abs(rewriter, results, {value}).succeeded());
auto result = results.getResults()[0];
EXPECT_EQ(
cast<IntegerAttr>(result.cast<Attribute>()).getValue().getSExtValue(),
-128);
}

// float
{
auto value = rewriter.getF32FloatAttr(-1.0);
TestPDLResultList results(1);
EXPECT_TRUE(builtin::abs(rewriter, results, {value}).succeeded());
auto result = results.getResults()[0];
EXPECT_EQ(
cast<FloatAttr>(result.cast<Attribute>()).getValue().convertToFloat(),
1.0);
}
}
} // namespace