From a5766f55a102325acf4a161452d7383cdbe7658c Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Tue, 18 Jun 2024 16:47:46 +0100 Subject: [PATCH 1/2] Fix TOSA cast constant op folding --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index 468961bd10f6d..cfe6286bc5973 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -888,7 +888,7 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { llvm::cast(outETy).getIntOrFloatBitWidth(), unsign); auto floatVal = operand.getSplatValue(); bool exact; - floatVal.convertToInteger(intVal, llvm::RoundingMode::TowardZero, &exact); + floatVal.convertToInteger(intVal, llvm::RoundingMode::NearestTiesToEven, &exact); return SplatElementsAttr::get(outTy, intVal); } From b36c81e3cffb22792b8d09f7e3094cf528bdb979 Mon Sep 17 00:00:00 2001 From: Corentin Ferry Date: Wed, 19 Jun 2024 07:08:10 +0100 Subject: [PATCH 2/2] Add a test case --- mlir/test/Dialect/Tosa/constant-op-fold.mlir | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir index 410eab79f7fef..c8b098b658a42 100644 --- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir +++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir @@ -602,6 +602,17 @@ func.func @cast_float_to_int() -> tensor { // ----- +// CHECK: func.func @cast_float_to_int_round +func.func @cast_float_to_int_round() -> tensor { + %splat = "tosa.const"() {value = dense<-3.5> : tensor} : () -> tensor + // CHECK: %[[SPLAT:.+]] = "tosa.const"() <{value = dense<-4> : tensor} + %cast = tosa.cast %splat : (tensor) -> tensor + // CHECK: return %[[SPLAT]] + return %cast : tensor +} + +// ----- + // CHECK: func.func @cast_int_to_int_trunc func.func @cast_int_to_int_trunc() -> tensor { %splat = "tosa.const"() {value = dense<-1> : tensor} : () -> tensor