From f77637f46012c67934b09558de951c96d6368ee3 Mon Sep 17 00:00:00 2001
From: Matthias Gehre <matthias.gehre@amd.com>
Date: Wed, 12 Jun 2024 11:37:23 +0200
Subject: [PATCH] TosaToArith: Use type converter for tosa.const (#195)

* TosaToArith: Use type converter for tosa.const

Co-authored-by: Tina Jung <tinamaria.jung@amd.com>
---
 .../mlir/Conversion/TosaToArith/TosaToArith.h |  4 ++-
 .../Conversion/TosaToArith/TosaToArith.cpp    | 25 +++++++++++++------
 .../TosaToArith/TosaToArithPass.cpp           |  6 ++++-
 .../Conversion/TosaToArith/tosa-to-arith.mlir | 14 +++++++----
 4 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
index e7158ee3852e1..1d651e394b897 100644
--- a/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
+++ b/mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h
@@ -16,6 +16,7 @@
 #include "mlir/Pass/Pass.h"
 
 namespace mlir {
+class TypeConverter;
 
 #define GEN_PASS_DECL_TOSATOARITH
 #include "mlir/Conversion/Passes.h.inc"
@@ -25,7 +26,8 @@ namespace tosa {
 std::unique_ptr<Pass> createTosaToArith(bool includeApplyRescale = false,
                                         bool use32BitApplyRescale = false);
 
-void populateTosaToArithConversionPatterns(RewritePatternSet *patterns);
+void populateTosaToArithConversionPatterns(TypeConverter &converter,
+                                           RewritePatternSet *patterns);
 
 void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns,
                                                   bool include32Bit = false);
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
index 50e57682a2dc8..fd42d1d444420 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArith.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 using namespace mlir;
@@ -22,13 +23,23 @@ using namespace tosa;
 
 namespace {
 
-class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
+class ConstOpConverter : public OpConversionPattern<tosa::ConstOp> {
 public:
-  using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
+  using OpConversionPattern::OpConversionPattern;
 
-  LogicalResult matchAndRewrite(tosa::ConstOp op,
-                                PatternRewriter &rewriter) const final {
-    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
+  LogicalResult matchAndRewrite(tosa::ConstOp op, OpAdaptor adaptor,
+                                ConversionPatternRewriter &rewriter) const final {
+
+    auto elements = dyn_cast<DenseElementsAttr>(adaptor.getValue());
+    if (!elements) {
+       return rewriter.notifyMatchFailure(op, "expected dense elements attr");
+    }
+
+    auto convertedElTy = getTypeConverter()->convertType(elements.getElementType());
+    if (!convertedElTy) {
+      return rewriter.notifyMatchFailure(op, "type conversion failed");
+    }
+    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, elements.bitcast(convertedElTy));
     return success();
   }
 };
@@ -238,9 +249,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
 
 } // namespace
 
-void mlir::tosa::populateTosaToArithConversionPatterns(
+void mlir::tosa::populateTosaToArithConversionPatterns(TypeConverter &converter,
     RewritePatternSet *patterns) {
-  patterns->add<ConstOpConverter>(patterns->getContext());
+  patterns->add<ConstOpConverter>(converter, patterns->getContext());
 }
 
 void mlir::tosa::populateTosaRescaleToArithConversionPatterns(
diff --git a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
index de82c0335c985..ff3f923b71fbf 100644
--- a/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
+++ b/mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
 
 namespace mlir {
 #define GEN_PASS_DEF_TOSATOARITH
@@ -34,12 +35,15 @@ struct TosaToArith : public impl::TosaToArithBase<TosaToArith> {
   TosaToArith(TosaToArithOptions &options) : TosaToArithBase(options) {}
 
   void runOnOperation() override {
+    TypeConverter converter;
+    mlir::tosa::populateTosaToLinalgTypeConversion(converter);
+
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addIllegalOp<tosa::ConstOp>();
     target.addLegalDialect<arith::ArithDialect>();
 
-    mlir::tosa::populateTosaToArithConversionPatterns(&patterns);
+    mlir::tosa::populateTosaToArithConversionPatterns(converter, &patterns);
 
     if (this->includeApplyRescale) {
       mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns,
diff --git a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
index c4f82d53af982..63d1423ea3ad6 100644
--- a/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
+++ b/mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir
@@ -2,12 +2,16 @@
 // RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=false" %s -verify-diagnostics -o -| FileCheck --check-prefix="SCALE" %s
 
 // CHECK-LABEL: func @const_test
-func.func @const_test() -> (tensor<i32>) {
-  // CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor<i32>
-  %result = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
+func.func @const_test() -> (tensor<i32>, tensor<ui32>) {
+  // CHECK: %[[CI32:.+]] = arith.constant dense<3> : tensor<i32>
+  %i32 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
 
-  // CHECK: return [[C3]]
-  return %result : tensor<i32>
+  // CHECK: %[[CUI32:.+]] = arith.constant dense<3> : tensor<i32>
+  // CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CUI32]] : tensor<i32> to tensor<ui32>
+  %ui32 = "tosa.const"() {value = dense<3> : tensor<ui32>} : () -> tensor<ui32>
+
+  // CHECK: return %[[CI32]], %[[CAST]]
+  return %i32, %ui32 : tensor<i32>,  tensor<ui32>
 }
 
 // -----