From 07c7f9ca6e6c4d5b7a7de04d580c4c2e00901728 Mon Sep 17 00:00:00 2001
From: Christopher McGirr <christopher.mcgirr@amd.com>
Date: Tue, 21 May 2024 10:51:17 +0100
Subject: [PATCH] fix(TosaCanonicalize): create FusedLoc when clamp folding

---
 mlir/include/mlir/IR/PatternMatch.h                | 14 ++++++++++++--
 mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp |  6 +++---
 .../Dialect/Tosa/canonicalize_with_debuginfo.mlir  | 14 ++++++++++++++
 3 files changed, 29 insertions(+), 5 deletions(-)
 create mode 100644 mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 2ce3bc3fc2e78..051dd0e62cc53 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -532,14 +532,24 @@ class RewriterBase : public OpBuilder {
   virtual void replaceOp(Operation *op, Operation *newOp);
 
   /// Replaces the result op with a new op that is created without verification.
+  /// The new op is created at a FusedLoc built from the locations in `locs`.
   /// The result values of the two ops must be the same types.
   template <typename OpTy, typename... Args>
-  OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
-    auto newOp = create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
+  OpTy replaceOpWithNewOp(Operation *op, ArrayRef<Location> locs,
+                          Args &&...args) {
+    auto newOp = create<OpTy>(getFusedLoc(locs), std::forward<Args>(args)...);
     replaceOp(op, newOp.getOperation());
     return newOp;
   }
 
+  /// Replaces `op` with a new op created without verification, using `op`'s own
+  /// location as the sole entry of `locs`. Both ops' result types must match.
+  template <typename OpTy, typename... Args>
+  OpTy replaceOpWithNewOp(Operation *op, Args &&...args) {
+    return replaceOpWithNewOp<OpTy, Args...>(op, {op->getLoc()},
+                                             std::forward<Args>(args)...);
+  }
+
   /// This method erases an operation that is known to have no uses.
   virtual void eraseOp(Operation *op);
 
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 6af6bfdd3ddac..468961bd10f6d 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -404,9 +404,9 @@ struct ClampClampOptimization : public OpRewritePattern<tosa::ClampOp> {
       auto minInt = std::max(op.getMinInt(), clampOp.getMinInt());
       auto maxInt = std::min(op.getMaxInt(), clampOp.getMaxInt());
 
-      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
-          op, op.getType(), clampOp.getInput(),
-          rewriter.getI64IntegerAttr(minInt),
+      rewriter.replaceOpWithNewOp<tosa::ClampOp>(
+          op, {op->getLoc(), clampOp->getLoc()}, op.getType(),
+          clampOp.getInput(), rewriter.getI64IntegerAttr(minInt),
           rewriter.getI64IntegerAttr(maxInt), rewriter.getF32FloatAttr(minFp),
           rewriter.getF32FloatAttr(maxFp));
       return success();
diff --git a/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir b/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir
new file mode 100644
index 0000000000000..2d646a0a150a3
--- /dev/null
+++ b/mlir/test/Dialect/Tosa/canonicalize_with_debuginfo.mlir
@@ -0,0 +1,14 @@
+// RUN: mlir-opt -mlir-print-debuginfo -canonicalize="test-convergence" %s | FileCheck %s
+
+// CHECK-LABEL: @clamp_twice_is_single_clamp
+func.func @clamp_twice_is_single_clamp(%arg0: tensor<4xi8>) -> tensor<4xi8> {
+  // CHECK: tosa.clamp %arg0 {max_fp = 3.000000e+00 : f32, max_int = 2 : i64, min_fp = -3.000000e+00 : f32, min_int = -2 : i64} {{.*}} loc(#[[FUSED:.*]])
+  // CHECK-DAG: #[[A:.*]] = loc("Clamp_A")
+  // CHECK-DAG: #[[B:.*]] = loc("Clamp_B")
+  // CHECK:     #[[FUSED]] = loc(fused[#[[B]], #[[A]]])
+  %0 = tosa.clamp %arg0 {max_fp = 3.0 : f32, max_int = 4 : i64, min_fp = -5.0 : f32, min_int = -2 : i64} :  (tensor<4xi8>) -> tensor<4xi8> loc(#loc0)
+  %1 = tosa.clamp %0 {max_fp = 5.0 : f32, max_int = 2 : i64, min_fp = -3.0 : f32, min_int = -4 : i64} :  (tensor<4xi8>) -> tensor<4xi8> loc(#loc1)
+  return %1 : tensor<4xi8>
+}
+#loc0 = loc("Clamp_A")
+#loc1 = loc("Clamp_B")