diff --git a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp index 8deb8f028ba45..d0ae2b64ffa9e 100644 --- a/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/ExpandOps.cpp @@ -160,6 +160,23 @@ struct FloorDivSIOpConverter : public OpRewritePattern { } }; +template +struct MaxMinIOpConverter : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const final { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + Location loc = op.getLoc(); + Value cmp = rewriter.create(loc, pred, lhs, rhs); + rewriter.replaceOpWithNewOp(op, cmp, lhs, rhs); + return success(); + } +}; + template struct MaximumMinimumFOpConverter : public OpRewritePattern { public: @@ -344,6 +361,10 @@ struct ArithExpandOpsPass arith::CeilDivSIOp, arith::CeilDivUIOp, arith::FloorDivSIOp, + arith::MaxSIOp, + arith::MaxUIOp, + arith::MinSIOp, + arith::MinUIOp, arith::MaximumFOp, arith::MinimumFOp, arith::MaxNumFOp, @@ -392,6 +413,10 @@ void mlir::arith::populateArithExpandOpsPatterns(RewritePatternSet &patterns) { populateCeilFloorDivExpandOpsPatterns(patterns); // clang-format off patterns.add< + MaxMinIOpConverter, + MaxMinIOpConverter, + MaxMinIOpConverter, + MaxMinIOpConverter, MaximumMinimumFOpConverter, MaximumMinimumFOpConverter, MaxNumMinNumFOpConverter, diff --git a/mlir/test/Dialect/Arith/expand-ops.mlir b/mlir/test/Dialect/Arith/expand-ops.mlir index 046e8ff64fba6..ec258a8138311 100644 --- a/mlir/test/Dialect/Arith/expand-ops.mlir +++ b/mlir/test/Dialect/Arith/expand-ops.mlir @@ -295,3 +295,51 @@ func.func @truncf_vector_f32(%arg0 : vector<4xf32>) -> vector<4xbf16> { // CHECK-LABEL: @truncf_vector_f32 // CHECK-NOT: arith.truncf + +// ----- + +func.func @maxsi(%a: i32, %b: i32) -> i32 { + %result = arith.maxsi %a, %b : i32 + return %result : i32 +} +// CHECK-LABEL: func @maxsi +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32 +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi sgt, %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: return %[[RESULT]] : i32 + +// ----- + +func.func @minsi(%a: i32, %b: i32) -> i32 { + %result = arith.minsi %a, %b : i32 + return %result : i32 +} +// CHECK-LABEL: func @minsi +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32 +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi slt, %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: return %[[RESULT]] : i32 + +// ----- + +func.func @maxui(%a: i32, %b: i32) -> i32 { + %result = arith.maxui %a, %b : i32 + return %result : i32 +} +// CHECK-LABEL: func @maxui +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32 +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ugt, %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: return %[[RESULT]] : i32 + +// ----- + +func.func @minui(%a: i32, %b: i32) -> i32 { + %result = arith.minui %a, %b : i32 + return %result : i32 +} +// CHECK-LABEL: func @minui +// CHECK-SAME: %[[LHS:.*]]: i32, %[[RHS:.*]]: i32 +// CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: %[[RESULT:.*]] = arith.select %[[CMP]], %[[LHS]], %[[RHS]] : i32 +// CHECK-NEXT: return %[[RESULT]] : i32