@@ -1845,49 +1845,181 @@ template <typename CmpOp>
struct MinMaxConverter : public OpRewritePattern<CmpOp> {
using OpRewritePattern<CmpOp>::OpRewritePattern;

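// Returns the single user of `val` as a `T`, or nullptr if `val` has more
// than one use or its only user is not a `T`.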
template <typename T> static T getOnlyUserOfType(Value val) {
if (!val || !val.hasOneUse()) {
return nullptr;
}
return dyn_cast<T>(*val.getUsers().begin());
}

// Only handle the Cmp + OrIOp + Select pattern here.
static arith::SelectOp findSelectThroughOr(Value cond) {
if (auto ori = getOnlyUserOfType<arith::OrIOp>(cond)) {
return getOnlyUserOfType<arith::SelectOp>(ori.getResult());
}
return nullptr;
}

MinMaxConverter(MLIRContext *context)
: OpRewritePattern<CmpOp>(context, /*benefit=*/10) {}

/// Helper that maps a floating-point compare predicate to the
corresponding min/max operation. This is parametrized by
/// whether we want NaN-aware operations (MaximumFOp/MinimumFOp) or
/// numeric operations (MaxNumFOp/MinNumFOp).
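/// For example, pred == OGT with useNaNOps == true produces
/// arith.maximumf(lhs, rhs), which propagates NaN inputs, while
/// useNaNOps == false produces arith.maxnumf(lhs, rhs), which returns the
/// non-NaN operand when exactly one input is NaN.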
FailureOr<Value> foldCmpToMinMax(PatternRewriter &rewriter, Location loc,
Value lhs, Value rhs,
arith::CmpFPredicate pred,
bool useNaNOps) const {
switch (pred) {
case arith::CmpFPredicate::OGT:
case arith::CmpFPredicate::OGE:
if (useNaNOps) {
return rewriter.create<arith::MaximumFOp>(loc, lhs, rhs).getResult();
} else {
return rewriter.create<arith::MaxNumFOp>(loc, lhs, rhs).getResult();
}
case arith::CmpFPredicate::OLT:
case arith::CmpFPredicate::OLE:
if (useNaNOps) {
return rewriter.create<arith::MinimumFOp>(loc, lhs, rhs).getResult();
} else {
return rewriter.create<arith::MinNumFOp>(loc, lhs, rhs).getResult();
}
default:
return failure();
}
}

LogicalResult matchAndRewrite(CmpOp cmpOp,
PatternRewriter &rewriter) const final {
Value result = cmpOp.getResult();
if (!result.hasOneUse()) {
return failure();
}

// 1. Simple pattern: cmpf + select.
if (auto selectOp = dyn_cast<arith::SelectOp>(*result.getUsers().begin())) {
if (!(result == selectOp.getCondition() &&
(cmpOp.getLhs() == selectOp.getTrueValue() &&
cmpOp.getRhs() == selectOp.getFalseValue()))) {
return failure();
}

rewriteOpWithMinMax(rewriter, cmpOp, selectOp, cmpOp.getPredicate());
rewriter.eraseOp(cmpOp);
return success();
}

// 2. NaN-aware pattern: cmpf + or + select.
auto selectOp = findSelectThroughOr(result);
if (!selectOp) {
return failure();
}

if (failed(foldCmpSelectToMinMax(rewriter, selectOp))) {
return failure();
}
return success();
}

/// foldCmpSelectToMinMax matches 'arith.select' operations whose condition
/// is a floating-point comparison OR'ed with a NaN check and rewrites them
/// into the equivalent NaN-propagating min/max operations.
///
/// This pattern handles the following cases:
///
/// ** NaN-Aware Min/Max Reduction **
/// - select ((cmpf ogt a, b) || (cmpf une a, a)), a, b --> arith.maximumf(a, b)
/// - select ((cmpf olt a, b) || (cmpf une a, a)), a, b --> arith.minimumf(a, b)
///
/// These transformations not only improve IR canonicalization but also
/// allow tt.reduce operations to be lowered to linalg operations, which the
/// triton-shared dialect conversion pipeline already supports.
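///
/// For example (illustrative IR; the value names are hypothetical):
///
///   %0 = arith.cmpf ogt, %a, %b : f32
///   %1 = arith.cmpf une, %a, %a : f32
///   %2 = arith.ori %0, %1 : i1
///   %3 = arith.select %2, %a, %b : f32
///
/// is rewritten to:
///
///   %3 = arith.maximumf %a, %b : f32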

LogicalResult foldCmpSelectToMinMax(PatternRewriter &rewriter,
arith::SelectOp sel) const {

if (!isa<FloatType>(sel.getType())) {
return failure();
}

Operation *condOp = sel.getCondition().getDefiningOp();
if (!condOp) {
return failure();
}

Value trueVal = sel.getTrueValue();
Value falseVal = sel.getFalseValue();

// NaN-Aware Min/Max Reduction.
auto ori = dyn_cast<arith::OrIOp>(condOp);
if (!ori)
return failure();
// Extract both sides of the OR condition.
auto cmp1 = ori.getLhs().getDefiningOp<arith::CmpFOp>();
auto cmp2 = ori.getRhs().getDefiningOp<arith::CmpFOp>();
if (!cmp1 || !cmp2)
return failure();

// Helper lambdas to identify comparison patterns.
auto isOGT = [&](arith::CmpFOp cmp) {
return cmp.getPredicate() == arith::CmpFPredicate::OGT &&
trueVal == cmp.getLhs() && falseVal == cmp.getRhs();
};
auto isOLT = [&](arith::CmpFOp cmp) {
return cmp.getPredicate() == arith::CmpFPredicate::OLT &&
trueVal == cmp.getLhs() && falseVal == cmp.getRhs();
};
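// Note: `une a, a` is the canonical NaN test; it is true iff `a` is NaN.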
auto isNaN = [&](arith::CmpFOp cmp) {
return cmp.getPredicate() == arith::CmpFPredicate::UNE &&
trueVal == cmp.getLhs() && trueVal == cmp.getRhs();
};

// Match: select ((ogt(a, b) || une(a, a)), a, b) -> arith.maximumf(a, b).
if ((isOGT(cmp1) && isNaN(cmp2)) || (isOGT(cmp2) && isNaN(cmp1))) {
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(sel);
FailureOr<Value> foldResult = foldCmpToMinMax(
rewriter, sel.getLoc(), trueVal, falseVal, arith::CmpFPredicate::OGT,
/*useNaNOps=*/true);
if (failed(foldResult)) {
return failure();
}
rewriter.replaceOp(sel, *foldResult);
return success();
}

// Match: select ((olt(a, b) || une(a, a)), a, b) -> arith.minimumf(a, b).
if ((isOLT(cmp1) && isNaN(cmp2)) || (isOLT(cmp2) && isNaN(cmp1))) {
PatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(sel);
FailureOr<Value> foldResult = foldCmpToMinMax(
rewriter, sel.getLoc(), trueVal, falseVal, arith::CmpFPredicate::OLT,
/*useNaNOps=*/true);
if (failed(foldResult)) {
return failure();
}
rewriter.replaceOp(sel, *foldResult);
return success();
}
return failure();
}

void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpFOp cmpOp,
arith::SelectOp selectOp,
arith::CmpFPredicate pred) const {
switch (pred) {
case arith::CmpFPredicate::OGT:
case arith::CmpFPredicate::OGE:
rewriter.replaceOpWithNewOp<arith::MaximumFOp>(selectOp, cmpOp.getLhs(),
cmpOp.getRhs());
break;
case arith::CmpFPredicate::OLT:
case arith::CmpFPredicate::OLE:
rewriter.replaceOpWithNewOp<arith::MinimumFOp>(selectOp, cmpOp.getLhs(),
cmpOp.getRhs());
break;
default: {
FailureOr<Value> foldedResult =
foldCmpToMinMax(rewriter, selectOp.getLoc(), cmpOp.getLhs(),
cmpOp.getRhs(), pred, /*useNaNOps=*/true);
if (failed(foldedResult)) {
llvm_unreachable("Unhandled predicate");
}
rewriter.replaceOp(selectOp, *foldedResult);
break;
}
}
}

void rewriteOpWithMinMax(PatternRewriter &rewriter, arith::CmpIOp cmpOp,
34 changes: 33 additions & 1 deletion test/Conversion/TritonToLinalg/convert_minmax_reduce.mlir
@@ -123,4 +123,36 @@ module {
// CHECK: %[[VAL_13:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [1], strides: [1] : memref<*xi32> to memref<1xi32, strided<[1]>>
// CHECK: affine.store %[[VAL_12]], %[[VAL_13]][0] : memref<1xi32, strided<[1]>>
// CHECK: return
// CHECK: }

// -----

module {
tt.func public @nan_aware_max(%arg0: tensor<1024xf32>, %arg_out: !tt.ptr<f32>) {
%res = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
^bb0(%lhs: f32, %rhs: f32):
%cmp_gt = arith.cmpf ogt, %lhs, %rhs : f32
%lhs_nan = arith.cmpf une, %lhs, %lhs : f32
%pred = arith.ori %cmp_gt, %lhs_nan : i1
%sel = arith.select %pred, %lhs, %rhs : f32
tt.reduce.return %sel : f32
}) : (tensor<1024xf32>) -> f32
tt.store %arg_out, %res : !tt.ptr<f32>
tt.return
}
}

// CHECK-LABEL: func.func @nan_aware_max
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1024xf32>, [[PARAM_1_:%.+]]: !tt.ptr<f32>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32, [[PARAM_5_:%.+]]: i32, [[PARAM_6_:%.+]]: i32, [[PARAM_7_:%.+]]: i32) {
// CHECK-DAG: [[CST_nan_:%.+]] = arith.constant 0xFF800000 : f32
// CHECK-DAG: [[VAR_0_:%.+]] = bufferization.alloc_tensor() : tensor<f32>
// CHECK: [[VAR_inserted_:%.+]] = tensor.insert [[CST_nan_]] into [[VAR_0_]][] : tensor<f32>
// CHECK: [[VAR_reduced_:%.+]] = linalg.reduce ins([[PARAM_0_]] : tensor<1024xf32>) outs([[VAR_inserted_]] : tensor<f32>) dimensions = [0]
// CHECK: ([[in_:%.+]]: f32, [[in_]]it: f32) {
// CHECK: [[CMP_gt_:%.+]] = arith.maximumf [[in_]], [[in_]]it : f32
// CHECK: linalg.yield [[CMP_gt_]] : f32
// CHECK: }
// CHECK: [[VAR_extracted_:%.+]] = tensor.extract [[VAR_reduced_]][] : tensor<f32>
// CHECK: tt.store [[PARAM_1_]], [[VAR_extracted_]] : !tt.ptr<f32>
// CHECK: return
// CHECK: }