From b8011b43032927a135806e0aec4858696945d372 Mon Sep 17 00:00:00 2001 From: quic-samanara Date: Tue, 18 Nov 2025 14:42:18 -0800 Subject: [PATCH 1/5] Refine MinMaxConverter Address review comments to reuse code --- .../ConversionPatterns.hpp | 200 ++++++++++++++++-- 1 file changed, 179 insertions(+), 21 deletions(-) diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp index 1d9244df..521c1573 100644 --- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp +++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp @@ -1845,47 +1845,205 @@ template struct MinMaxConverter : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; + template static T getOnlyUserOfType(Value val) { + if (!val || !val.hasOneUse()) { + return nullptr; + } + return dyn_cast(*val.getUsers().begin()); + } + + // Only handle the Cmp + OrIOp + Select pattern here. + static arith::SelectOp findSelectThroughOr(Value cond) { + if (auto ori = getOnlyUserOfType(cond)) { + return getOnlyUserOfType(ori.getResult()); + } + return nullptr; + } + MinMaxConverter(MLIRContext *context) : OpRewritePattern(context, /*benefit=*/10) {} + + /// Helper that maps a floating-point compare predicate to the + /// corresponding min/max operation. THis is parametrized by + /// whether we want NaN-aware operations (MaximumFOp/MinimumFOp) or + /// numeric operations (MaxNumFOp/MinNumFOp). 
+ LogicalResult foldCmpToMinMax(PatternRewriter &rewriter, Location loc, + Value lhs, Value rhs, arith::CmpFPredicate pred, + bool useNaNOps, Value &result) const { + switch (pred) { + case arith::CmpFPredicate::OGT: + case arith::CmpFPredicate::OGE: + if (useNaNOps) { + result = rewriter.create(loc, lhs, rhs).getResult(); + } else { + result = rewriter.create(loc, lhs, rhs).getResult(); + } + return success(); + case arith::CmpFPredicate::OLT: + case arith::CmpFPredicate::OLE: + if (useNaNOps) { + result = rewriter.create(loc, lhs, rhs).getResult(); + } else { + result = rewriter.create(loc, lhs, rhs).getResult(); + } + return success(); + default: + return failure(); + } + } LogicalResult matchAndRewrite(CmpOp cmpOp, PatternRewriter &rewriter) const final { - if (!cmpOp.getResult().hasOneUse()) { + Value result = cmpOp.getResult(); + if (!result.hasOneUse()) { return failure(); } - auto selectOp = - dyn_cast(*cmpOp.getResult().getUsers().begin()); + + // 1. Simple pattern: cmpf + select. + if (auto selectOp = dyn_cast(*result.getUsers().begin())) { + if (!(result == selectOp.getCondition() && + (cmpOp.getLhs() == selectOp.getTrueValue() && + cmpOp.getRhs() == selectOp.getFalseValue()))) { + return failure(); + } + + rewriteOpWithMinMax(rewriter, cmpOp, selectOp, cmpOp.getPredicate()); + rewriter.eraseOp(cmpOp); + return success(); + } + + // 2. NaN-aware pattern: cmpf + or + select. + auto selectOp = findSelectThroughOr(result); if (!selectOp) { return failure(); } - if (!(cmpOp.getResult() == selectOp.getCondition() && - cmpOp.getLhs() == selectOp.getTrueValue() && - cmpOp.getRhs() == selectOp.getFalseValue())) { + if (failed(foldCmpSelectToMinMax(rewriter, selectOp))) { + return failure(); + } + return success(); + } + + /// foldCmpSelectToMinMax performs an optimization pattern that matches + /// 'arith.select' operations based on a floating-point comparison + /// and rewrites them into equivalent numeric min/max operations. 
+ /// + /// This pattern handles the following cases: + /// + /// 1. ** Simple Min/Max Reduction ** + /// - select (cmpf ogt a, b), a, b --> arith.maxnumf(a, b) + /// - select (cmpf olt a, b), a, b --> arith.minnumf(a, b) + /// + /// 2. ** NaN-Aware Min/Max Reduction ** + /// - select (cmpf ogt a, b) || cmpf une a, a), a, b --> arith.maximumf(a, + /// b) + /// - select (cmpf olt a, b) || cmpf une a, a), a, b --> arith.minimumf(a, + /// b) + /// + /// These transformations not only improve IR canonicalization but also + /// allow the successful lowering of tt.reduce operations to linalg + /// operations, which is already supported in the triton-shared dialect + /// conversion pipeline. + + LogicalResult foldCmpSelectToMinMax(PatternRewriter &rewriter, + arith::SelectOp sel) const { + + if (!isa(sel.getType())) { + return failure(); + } + + Operation *condOp = sel.getCondition().getDefiningOp(); + if (!condOp) { return failure(); } - rewriteOpWithMinMax(rewriter, cmpOp, selectOp, cmpOp.getPredicate()); - rewriter.eraseOp(cmpOp); + Value trueVal = sel.getTrueValue(); + Value falseVal = sel.getFalseValue(); - return success(); + // Case 1: Simple Min/Max Reduction (numeric min/max). + if (auto cmp = dyn_cast(condOp)) { + + if (cmp.getLhs() == trueVal && cmp.getRhs() == falseVal) { + auto pred = cmp.getPredicate(); + // Only fold OGT/OLT predicates. + if (pred == arith::CmpFPredicate::OGT || + pred == arith::CmpFPredicate::OLT) { + rewriter.setInsertionPoint(sel); + Value foldedResult; + if (failed(foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, + pred, /*useNaNOps=*/false, + foldedResult))) { + return failure(); + } + rewriter.replaceOp(sel, foldedResult); + return success(); + } + } + } + + // Case 2: NaN-Aware Min/Max Reduction. + if (auto ori = dyn_cast(condOp)) { + // Extract both sides of the OR condition. 
+ auto cmp1 = ori.getLhs().getDefiningOp(); + auto cmp2 = ori.getRhs().getDefiningOp(); + if (!cmp1 || !cmp2) + return failure(); + + // Helper lambdas to identify comparison patterns. + auto isOGT = [&](arith::CmpFOp cmp) { + return cmp.getPredicate() == arith::CmpFPredicate::OGT && + trueVal == cmp.getLhs() && falseVal == cmp.getRhs(); + }; + auto isOLT = [&](arith::CmpFOp cmp) { + return cmp.getPredicate() == arith::CmpFPredicate::OLT && + trueVal == cmp.getLhs() && falseVal == cmp.getRhs(); + }; + auto isNaN = [&](arith::CmpFOp cmp) { + return cmp.getPredicate() == arith::CmpFPredicate::UNE && + trueVal == cmp.getLhs() && trueVal == cmp.getRhs(); + }; + + // Match: select ((ogt(a, b) || une(a, a)), a, b) -> arith.maximumf(a, b). + if ((isOGT(cmp1) && isNaN(cmp2)) || (isOGT(cmp2) && isNaN(cmp1))) { + rewriter.setInsertionPoint(sel); + Value foldedResult; + if (failed(foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, + arith::CmpFPredicate::OGT, + /*useNaNOps=*/true, foldedResult))) { + return failure(); + } + rewriter.replaceOp(sel, foldedResult); + return success(); + } + + // Match: select ((olt(a, b) || une(a, a)), a, b) -> arith.minimumf(a, b). 
+ if ((isOLT(cmp1) && isNaN(cmp2)) || (isOLT(cmp2) && isNaN(cmp1))) { + rewriter.setInsertionPoint(sel); + Value foldedResult; + if (failed(foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, + arith::CmpFPredicate::OLT, + /*useNaNOps=*/true, foldedResult))) { + return failure(); + } + rewriter.replaceOp(sel, foldedResult); + return success(); + } + } + return failure(); } void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpFOp cmpOp, arith::SelectOp selectOp, arith::CmpFPredicate pred) const { - switch (pred) { - case arith::CmpFPredicate::OGT: - case arith::CmpFPredicate::OGE: - rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), - cmpOp.getRhs()); - break; - case arith::CmpFPredicate::OLT: - case arith::CmpFPredicate::OLE: - rewriter.replaceOpWithNewOp(selectOp, cmpOp.getLhs(), - cmpOp.getRhs()); - break; - default: + Value foldedResult; + // For the generic cmp+select pattern, we use the NaN-aware ops + // (arith::MaximumFOp/arith::MinimumFOp) to preserve semantics in the + // presence of NaN values. + if (succeeded(foldCmpToMinMax(rewriter, selectOp.getLoc(), cmpOp.getLhs(), + cmpOp.getRhs(), pred, /*useNaNOps=*/true, + foldedResult))) { + rewriter.replaceOp(selectOp, foldedResult); + } else { llvm_unreachable("Unhandled predicate"); } } From bee6347771aa66213dc1dd09ea878a270b90b10b Mon Sep 17 00:00:00 2001 From: quic-samanara Date: Tue, 25 Nov 2025 12:25:09 -0800 Subject: [PATCH 2/5] Use Failure instead of passing in reference to output parameter. 
--- .../ConversionPatterns.hpp | 63 +++++++++---------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp index 521c1573..0d80a894 100644 --- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp +++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp @@ -1862,31 +1862,31 @@ struct MinMaxConverter : public OpRewritePattern { MinMaxConverter(MLIRContext *context) : OpRewritePattern(context, /*benefit=*/10) {} - + /// Helper that maps a floating-point compare predicate to the /// corresponding min/max operation. THis is parametrized by /// whether we want NaN-aware operations (MaximumFOp/MinimumFOp) or /// numeric operations (MaxNumFOp/MinNumFOp). - LogicalResult foldCmpToMinMax(PatternRewriter &rewriter, Location loc, - Value lhs, Value rhs, arith::CmpFPredicate pred, - bool useNaNOps, Value &result) const { + FailureOr foldCmpToMinMax(PatternRewriter &rewriter, Location loc, + Value lhs, Value rhs, + arith::CmpFPredicate pred, + bool useNaNOps) const { switch (pred) { case arith::CmpFPredicate::OGT: case arith::CmpFPredicate::OGE: if (useNaNOps) { - result = rewriter.create(loc, lhs, rhs).getResult(); + return rewriter.create(loc, lhs, rhs).getResult(); } else { - result = rewriter.create(loc, lhs, rhs).getResult(); + return rewriter.create(loc, lhs, rhs).getResult(); } return success(); case arith::CmpFPredicate::OLT: case arith::CmpFPredicate::OLE: if (useNaNOps) { - result = rewriter.create(loc, lhs, rhs).getResult(); + return rewriter.create(loc, lhs, rhs).getResult(); } else { - result = rewriter.create(loc, lhs, rhs).getResult(); + return rewriter.create(loc, lhs, rhs).getResult(); } - return success(); default: return failure(); } @@ -1970,12 +1970,13 @@ struct MinMaxConverter : public OpRewritePattern { pred == arith::CmpFPredicate::OLT) { 
rewriter.setInsertionPoint(sel); Value foldedResult; - if (failed(foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, - pred, /*useNaNOps=*/false, - foldedResult))) { + FailureOr foldResult = + foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, pred, + /*useNaNOps=*/false); + if (failed(foldResult)) { return failure(); } - rewriter.replaceOp(sel, foldedResult); + rewriter.replaceOp(sel, *foldResult); return success(); } } @@ -2006,26 +2007,28 @@ struct MinMaxConverter : public OpRewritePattern { // Match: select ((ogt(a, b) || une(a, a)), a, b) -> arith.maximumf(a, b). if ((isOGT(cmp1) && isNaN(cmp2)) || (isOGT(cmp2) && isNaN(cmp1))) { rewriter.setInsertionPoint(sel); - Value foldedResult; - if (failed(foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, - arith::CmpFPredicate::OGT, - /*useNaNOps=*/true, foldedResult))) { + FailureOr foldResult = + foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, + arith::CmpFPredicate::OGT, + /*useNaNOps=*/true); + if (failed(foldResult)) { return failure(); } - rewriter.replaceOp(sel, foldedResult); + rewriter.replaceOp(sel, *foldResult); return success(); } // Match: select ((olt(a, b) || une(a, a)), a, b) -> arith.minimumf(a, b). 
if ((isOLT(cmp1) && isNaN(cmp2)) || (isOLT(cmp2) && isNaN(cmp1))) { rewriter.setInsertionPoint(sel); - Value foldedResult; - if (failed(foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, - arith::CmpFPredicate::OLT, - /*useNaNOps=*/true, foldedResult))) { + FailureOr foldResult = + foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, + arith::CmpFPredicate::OLT, + /*useNaNOps=*/true); + if (failed(foldResult)) { return failure(); } - rewriter.replaceOp(sel, foldedResult); + rewriter.replaceOp(sel, *foldResult); return success(); } } @@ -2035,17 +2038,13 @@ struct MinMaxConverter : public OpRewritePattern { void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpFOp cmpOp, arith::SelectOp selectOp, arith::CmpFPredicate pred) const { - Value foldedResult; - // For the generic cmp+select pattern, we use the NaN-aware ops - // (arith::MaximumFOp/arith::MinimumFOp) to preserve semantics in the - // presence of NaN values. - if (succeeded(foldCmpToMinMax(rewriter, selectOp.getLoc(), cmpOp.getLhs(), - cmpOp.getRhs(), pred, /*useNaNOps=*/true, - foldedResult))) { - rewriter.replaceOp(selectOp, foldedResult); - } else { + FailureOr foldedResult = + foldCmpToMinMax(rewriter, selectOp.getLoc(), cmpOp.getLhs(), + cmpOp.getRhs(), pred, /*useNaNOps=*/true); + if (failed(foldedResult)) { llvm_unreachable("Unhandled predicate"); } + rewriter.replaceOp(selectOp, *foldedResult); } void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpIOp cmpOp, From 9b59bb57a0cc4b0cc98c2f92b552cc3adc2c6b3a Mon Sep 17 00:00:00 2001 From: quic-samanara Date: Tue, 25 Nov 2025 12:28:51 -0800 Subject: [PATCH 3/5] Add insertion guards to ensure that the insertion point is restored after pattern rewrites. 
--- .../Conversion/TritonArithToLinalg/ConversionPatterns.hpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp index 0d80a894..4e70bbc2 100644 --- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp +++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp @@ -1968,6 +1968,7 @@ struct MinMaxConverter : public OpRewritePattern { // Only fold OGT/OLT predicates. if (pred == arith::CmpFPredicate::OGT || pred == arith::CmpFPredicate::OLT) { + PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sel); Value foldedResult; FailureOr foldResult = @@ -2006,6 +2007,7 @@ struct MinMaxConverter : public OpRewritePattern { // Match: select ((ogt(a, b) || une(a, a)), a, b) -> arith.maximumf(a, b). if ((isOGT(cmp1) && isNaN(cmp2)) || (isOGT(cmp2) && isNaN(cmp1))) { + PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sel); FailureOr foldResult = foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, @@ -2020,6 +2022,7 @@ struct MinMaxConverter : public OpRewritePattern { // Match: select ((olt(a, b) || une(a, a)), a, b) -> arith.minimumf(a, b). 
if ((isOLT(cmp1) && isNaN(cmp2)) || (isOLT(cmp2) && isNaN(cmp1))) { + PatternRewriter::InsertionGuard guard(rewriter); rewriter.setInsertionPoint(sel); FailureOr foldResult = foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, From 713c8364253b399d52c236b1d39f94f7e0db7232 Mon Sep 17 00:00:00 2001 From: quic-samanara Date: Tue, 2 Dec 2025 17:05:30 -0800 Subject: [PATCH 4/5] Simplify control flow in foldCmpSelectToMinMax pattern --- .../ConversionPatterns.hpp | 122 +++++++----------- 1 file changed, 47 insertions(+), 75 deletions(-) diff --git a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp index 4e70bbc2..32e279b4 100644 --- a/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp +++ b/include/triton-shared/Conversion/TritonArithToLinalg/ConversionPatterns.hpp @@ -1928,13 +1928,9 @@ struct MinMaxConverter : public OpRewritePattern { /// 'arith.select' operations based on a floating-point comparison /// and rewrites them into equivalent numeric min/max operations. /// - /// This pattern handles the following cases: + /// This pattern handles the following case: /// - /// 1. ** Simple Min/Max Reduction ** - /// - select (cmpf ogt a, b), a, b --> arith.maxnumf(a, b) - /// - select (cmpf olt a, b), a, b --> arith.minnumf(a, b) - /// - /// 2. ** NaN-Aware Min/Max Reduction ** + /// ** NaN-Aware Min/Max Reduction ** /// - select (cmpf ogt a, b) || cmpf une a, a), a, b --> arith.maximumf(a, /// b) /// - select (cmpf olt a, b) || cmpf une a, a), a, b --> arith.minimumf(a, @@ -1960,80 +1956,56 @@ struct MinMaxConverter : public OpRewritePattern { Value trueVal = sel.getTrueValue(); Value falseVal = sel.getFalseValue(); - // Case 1: Simple Min/Max Reduction (numeric min/max). - if (auto cmp = dyn_cast(condOp)) { - - if (cmp.getLhs() == trueVal && cmp.getRhs() == falseVal) { - auto pred = cmp.getPredicate(); - // Only fold OGT/OLT predicates. 
- if (pred == arith::CmpFPredicate::OGT || - pred == arith::CmpFPredicate::OLT) { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(sel); - Value foldedResult; - FailureOr foldResult = - foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, pred, - /*useNaNOps=*/false); - if (failed(foldResult)) { - return failure(); - } - rewriter.replaceOp(sel, *foldResult); - return success(); - } - } - } + // NaN-Aware Min/Max Reduction. + auto ori = dyn_cast(condOp); + if (!ori) + return failure(); + // Extract both sides of the OR condition. + auto cmp1 = ori.getLhs().getDefiningOp(); + auto cmp2 = ori.getRhs().getDefiningOp(); + if (!cmp1 || !cmp2) + return failure(); - // Case 2: NaN-Aware Min/Max Reduction. - if (auto ori = dyn_cast(condOp)) { - // Extract both sides of the OR condition. - auto cmp1 = ori.getLhs().getDefiningOp(); - auto cmp2 = ori.getRhs().getDefiningOp(); - if (!cmp1 || !cmp2) + // Helper lambdas to identify comparison patterns. + auto isOGT = [&](arith::CmpFOp cmp) { + return cmp.getPredicate() == arith::CmpFPredicate::OGT && + trueVal == cmp.getLhs() && falseVal == cmp.getRhs(); + }; + auto isOLT = [&](arith::CmpFOp cmp) { + return cmp.getPredicate() == arith::CmpFPredicate::OLT && + trueVal == cmp.getLhs() && falseVal == cmp.getRhs(); + }; + auto isNaN = [&](arith::CmpFOp cmp) { + return cmp.getPredicate() == arith::CmpFPredicate::UNE && + trueVal == cmp.getLhs() && trueVal == cmp.getRhs(); + }; + + // Match: select ((ogt(a, b) || une(a, a)), a, b) -> arith.maximumf(a, b). + if ((isOGT(cmp1) && isNaN(cmp2)) || (isOGT(cmp2) && isNaN(cmp1))) { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(sel); + FailureOr foldResult = foldCmpToMinMax( + rewriter, sel.getLoc(), trueVal, falseVal, arith::CmpFPredicate::OGT, + /*useNaNOps=*/true); + if (failed(foldResult)) { return failure(); - - // Helper lambdas to identify comparison patterns. 
- auto isOGT = [&](arith::CmpFOp cmp) { - return cmp.getPredicate() == arith::CmpFPredicate::OGT && - trueVal == cmp.getLhs() && falseVal == cmp.getRhs(); - }; - auto isOLT = [&](arith::CmpFOp cmp) { - return cmp.getPredicate() == arith::CmpFPredicate::OLT && - trueVal == cmp.getLhs() && falseVal == cmp.getRhs(); - }; - auto isNaN = [&](arith::CmpFOp cmp) { - return cmp.getPredicate() == arith::CmpFPredicate::UNE && - trueVal == cmp.getLhs() && trueVal == cmp.getRhs(); - }; - - // Match: select ((ogt(a, b) || une(a, a)), a, b) -> arith.maximumf(a, b). - if ((isOGT(cmp1) && isNaN(cmp2)) || (isOGT(cmp2) && isNaN(cmp1))) { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(sel); - FailureOr foldResult = - foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, - arith::CmpFPredicate::OGT, - /*useNaNOps=*/true); - if (failed(foldResult)) { - return failure(); - } - rewriter.replaceOp(sel, *foldResult); - return success(); } + rewriter.replaceOp(sel, *foldResult); + return success(); + } - // Match: select ((olt(a, b) || une(a, a)), a, b) -> arith.minimumf(a, b). - if ((isOLT(cmp1) && isNaN(cmp2)) || (isOLT(cmp2) && isNaN(cmp1))) { - PatternRewriter::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(sel); - FailureOr foldResult = - foldCmpToMinMax(rewriter, sel.getLoc(), trueVal, falseVal, - arith::CmpFPredicate::OLT, - /*useNaNOps=*/true); - if (failed(foldResult)) { - return failure(); - } - rewriter.replaceOp(sel, *foldResult); - return success(); + // Match: select ((olt(a, b) || une(a, a)), a, b) -> arith.minimumf(a, b). 
+ if ((isOLT(cmp1) && isNaN(cmp2)) || (isOLT(cmp2) && isNaN(cmp1))) { + PatternRewriter::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(sel); + FailureOr foldResult = foldCmpToMinMax( + rewriter, sel.getLoc(), trueVal, falseVal, arith::CmpFPredicate::OLT, + /*useNaNOps=*/true); + if (failed(foldResult)) { + return failure(); } + rewriter.replaceOp(sel, *foldResult); + return success(); } return failure(); } From f1efef1ad1f38f78bbbb173f4ff6bb74bd0713cf Mon Sep 17 00:00:00 2001 From: quic-samanara Date: Tue, 2 Dec 2025 18:05:04 -0800 Subject: [PATCH 5/5] Add LIT test --- .../TritonToLinalg/convert_minmax_reduce.mlir | 34 ++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir b/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir index eaf30630..c02a2a81 100644 --- a/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir +++ b/test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir @@ -123,4 +123,36 @@ module { // CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>> // CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>> // CHECK: return -// CHECK: } \ No newline at end of file +// CHECK: } + +// ----- + +module { + tt.func public @nan_aware_max(%arg0: tensor<1024xf32>, %arg_out: !tt.ptr) { + %res = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({ + ^bb0(%lhs: f32, %rhs: f32): + %cmp_gt = arith.cmpf ogt, %lhs, %rhs : f32 + %lhs_nan = arith.cmpf une, %lhs, %lhs : f32 + %pred = arith.ori %cmp_gt, %lhs_nan : i1 + %sel = arith.select %pred, %lhs, %rhs : f32 + tt.reduce.return %sel : f32 + }) : (tensor<1024xf32>) -> f32 + tt.store %arg_out, %res : !tt.ptr + tt.return +} +} + +// CHECK-LABEL: func.func @nan_aware_max +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1024xf32>, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: 
i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) { +// CHECK-DAG: [[CST_nan_:%.+]] = arith.constant 0xFF800000 : f32 +// CHECK-DAG: [[VAR_0_:%.+]] = bufferization.alloc_tensor() : tensor +// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_nan_]] into [[VAR_0_]][] : tensor +// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[PARAM_0_]] : tensor<1024xf32>) outs([[VAR_inserted_]] : tensor) dimensions = [0] +// CHECK: ([[in_:%.+]]: f32, [[in_]]it: f32) { +// CHECK: [[CMP_gt_:%.+]] = arith.maximumf [[in_]], [[in_]]it : f32 +// CHECK: linalg.yield [[CMP_gt_]] : f32 +// CHECK: } +// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor +// CHECK: tt.store [[PARAM_1_]], [[VAR_extracted_]] : !tt.ptr +// CHECK: return +// CHECK: }