From 44b6eaec167518fd835c8433b1fad22f3e458d01 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 29 Apr 2024 13:56:29 +0200 Subject: [PATCH 1/2] expand-ops: minsi/maxsi (and unsigned) --- .../Dialect/Arith/Transforms/ExpandOps.cpp | 25 ++++++++++ mlir/test/Dialect/Arith/expand-ops.mlir | 48 +++++++++++++++++++ 2 files changed, 73 insertions(+) diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 8deb8f028ba45..d0ae2b64ffa9e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -160,6 +160,23 @@ struct FloorDivSIOpConverter : public OpRewritePattern { } }; +template +struct MaxMinIOpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const final { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + Location loc = op.getLoc(); + Value cmp = rewriter.create(loc, pred, lhs, rhs); + rewriter.replaceOpWithNewOp(op, cmp, lhs, rhs); + return success(); + } +}; + template struct MaximumMinimumFOpConverter : public OpRewritePattern { public: @@ -344,6 +361,10 @@ struct ArithExpandOpsPass arith::CeilDivSIOp, arith::CeilDivUIOp, arith::FloorDivSIOp, + arith::MaxSIOp, + arith::MaxUIOp, + arith::MinSIOp, + arith::MinUIOp, arith::MaximumFOp, arith::MinimumFOp, arith::MaxNumFOp, @@ -392,6 +413,10 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { populateCeilFloorDivExpandOpsPatterns(patterns); // clang-format off patterns.add< + MaxMinIOpConverter, + MaxMinIOpConverter, + MaxMinIOpConverter, + MaxMinIOpConverter, MaximumMinimumFOpConverter, MaximumMinimumFOpConverter, MaxNumMinNumFOpConverter, diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir index 046e8ff64fba6..a08edf47ee084 100644 --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -295,3 +295,51 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> { // CHECK-LABEL: @truncf_vector_f32 // CHECK-NOT: arith.truncf + +// ----- + +func.func @maxsi(%a: i32, %b: i32) -> i32 { + %result = arith.maxsi %a, %b : i32 + return %result : i32 +} +// CHECK-LABEL: func @maxsi +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: return %[[RESULT]] : i32 + +// ----- + +func.func @minsi(%a: i32, %b: i32) -> i32 { + %result = arith.minsi %a, %b : i32 + return %result : i32 +} +// CHECK-LABEL: func @minsi +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: return %[[RESULT]] : i32 + +// ----- + +func.func @maxui(%a: i32, %b: i32) -> i32 { + %result = arith.maxui %a, %b : i32 + return %result : i32 +} +// CHECK-LABEL: func @maxui +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: return %[[RESULT]] : i32 + +// ----- + +func.func @minui(%a: i32, %b: i32) -> i32 { + %result = arith.minui %a, %b : i32 + return %result : i32 +} +// CHECK-LABEL: func @minui +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: return %[[RESULT]] : i32 From e82a97762587350ddd3c579cfac7b8d105d41658 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 29 Apr 2024 15:52:58 +0200 Subject: [PATCH 2/2] Update tests --- mlir/test/Dialect/Arith/expand-ops.mlir | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir index a08edf47ee084..ec258a8138311 100644 --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -303,7 +303,7 @@ func.func @maxsi(%a: i32, %b: i32) -> i32 { return %result : i32 } // CHECK-LABEL: func @maxsi -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32 // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: return %[[RESULT]] : i32 @@ -315,7 +315,7 @@ func.func @minsi(%a: i32, %b: i32) -> i32 { return %result : i32 } // CHECK-LABEL: func @minsi -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32 // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: return %[[RESULT]] : i32 @@ -327,7 +327,7 @@ func.func @maxui(%a: i32, %b: i32) -> i32 { return %result : i32 } // CHECK-LABEL: func @maxui -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32 // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: return %[[RESULT]] : i32 @@ -339,7 +339,7 @@ func.func @minui(%a: i32, %b: i32) -> i32 { return %result : i32 } // CHECK-LABEL: func @minui -// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32) +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32 // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 // CHECK-NEXT: return %[[RESULT]] : i32