From d57479cfbe9a6b4dccedfd1221c04973ad90ec97 Mon Sep 17 00:00:00 2001 From: Jerry-Ge Date: Wed, 19 Feb 2025 10:19:57 -0800 Subject: [PATCH] [mlir][tosa] Update SelectOp's input names to match TOSA specification (#127833) Updated: - pred to input1 - on_true to input2 - on_false to input3 Signed-off-by: Jerry Ge --- mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td | 8 ++++---- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 14 +++++++------- .../Tosa/Transforms/TosaMakeBroadcastable.cpp | 6 +++--- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td index 4d5837ca26c91..7cdf79f4dc59d 100644 --- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td +++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td @@ -1190,9 +1190,9 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> { }]; let arguments = (ins - Tosa_I1Tensor:$pred, - Tosa_Tensor:$on_true, - Tosa_Tensor:$on_false + Tosa_I1Tensor:$input1, + Tosa_Tensor:$input2, + Tosa_Tensor:$input3 ); let results = (outs @@ -1202,7 +1202,7 @@ def Tosa_SelectOp : Tosa_ElementwiseOp<"select"> { let hasFolder = 1; let assemblyFormat = [{ - operands attr-dict `:` `(` type($pred) `,` type($on_true) `,` type($on_false) + operands attr-dict `:` `(` type($input1) `,` type($input2) `,` type($input3) `)` `->` type($output) }]; } diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index b9bcedb7fe71d..9bfc2aae1d6a5 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -65,12 +65,12 @@ void ConcatOp::getCanonicalizationPatterns(RewritePatternSet &results, } LogicalResult SelectOp::canonicalize(SelectOp op, PatternRewriter &rewriter) { - auto notOp = op.getPred().getDefiningOp(); + auto notOp = op.getInput1().getDefiningOp(); if (!notOp) return failure(); rewriter.modifyOpInPlace(op, [&]() { op.getOperation()->setOperands( - {notOp.getInput1(), op.getOnFalse(), op.getOnTrue()}); + {notOp.getInput1(), op.getInput3(), op.getInput2()}); }); return success(); } @@ -1131,18 +1131,18 @@ OpFoldResult SliceOp::fold(FoldAdaptor adaptor) { } OpFoldResult tosa::SelectOp::fold(FoldAdaptor adaptor) { - if (getOnTrue() == getOnFalse()) - return getOnTrue(); + if (getInput2() == getInput3()) + return getInput2(); auto predicate = - llvm::dyn_cast_if_present(adaptor.getPred()); + llvm::dyn_cast_if_present(adaptor.getInput1()); if (!predicate) return {}; if (!predicate.isSplat()) return {}; - return predicate.getSplatValue().getBoolValue() ? getOnTrue() - : getOnFalse(); + return predicate.getSplatValue().getBoolValue() ? getInput2() + : getInput3(); } OpFoldResult TileOp::fold(FoldAdaptor adaptor) { diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp index 79afc75fd6c8e..87b2a2695351b 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp @@ -169,9 +169,9 @@ struct ConvertTosaOp : public OpRewritePattern { LogicalResult matchAndRewrite(tosa::SelectOp tosaOp, PatternRewriter &rewriter) const override { - Value input1 = tosaOp.getPred(); - Value input2 = tosaOp.getOnTrue(); - Value input3 = tosaOp.getOnFalse(); + Value input1 = tosaOp.getInput1(); + Value input2 = tosaOp.getInput2(); + Value input3 = tosaOp.getInput3(); Value output = tosaOp.getResult(); auto outputType = dyn_cast(output.getType());