Skip to content

Commit 4582fe6

Browse files
committed
Merge TosaToArith: Use type converter for tosa.const (#195) into HEAD
2 parents a6f8295 + f77637f commit 4582fe6

File tree

4 files changed

+35
-14
lines changed

4 files changed

+35
-14
lines changed

mlir/include/mlir/Conversion/TosaToArith/TosaToArith.h

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include "mlir/Pass/Pass.h"
1717

1818
namespace mlir {
19+
class TypeConverter;
1920

2021
#define GEN_PASS_DECL_TOSATOARITH
2122
#include "mlir/Conversion/Passes.h.inc"
@@ -25,7 +26,8 @@ namespace tosa {
2526
std::unique_ptr<Pass> createTosaToArith(bool includeApplyRescale = false,
2627
bool use32BitApplyRescale = false);
2728

28-
void populateTosaToArithConversionPatterns(RewritePatternSet *patterns);
29+
void populateTosaToArithConversionPatterns(TypeConverter &converter,
30+
RewritePatternSet *patterns);
2931

3032
void populateTosaRescaleToArithConversionPatterns(RewritePatternSet *patterns,
3133
bool include32Bit = false);

mlir/lib/Conversion/TosaToArith/TosaToArith.cpp

+18-7
Original file line numberDiff line numberDiff line change
@@ -15,20 +15,31 @@
1515
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
1616
#include "mlir/IR/PatternMatch.h"
1717
#include "mlir/IR/TypeUtilities.h"
18+
#include "mlir/Transforms/DialectConversion.h"
1819
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1920

2021
using namespace mlir;
2122
using namespace tosa;
2223

2324
namespace {
2425

25-
class ConstOpConverter : public OpRewritePattern<tosa::ConstOp> {
26+
class ConstOpConverter : public OpConversionPattern<tosa::ConstOp> {
2627
public:
27-
using OpRewritePattern<tosa::ConstOp>::OpRewritePattern;
28+
using OpConversionPattern::OpConversionPattern;
2829

29-
LogicalResult matchAndRewrite(tosa::ConstOp op,
30-
PatternRewriter &rewriter) const final {
31-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, op.getValue());
30+
LogicalResult matchAndRewrite(tosa::ConstOp op, OpAdaptor adaptor,
31+
ConversionPatternRewriter &rewriter) const final {
32+
33+
auto elements = dyn_cast<DenseElementsAttr>(adaptor.getValue());
34+
if (!elements) {
35+
return rewriter.notifyMatchFailure(op, "expected dense elements attr");
36+
}
37+
38+
auto convertedElTy = getTypeConverter()->convertType(elements.getElementType());
39+
if (!convertedElTy) {
40+
return rewriter.notifyMatchFailure(op, "type conversion failed");
41+
}
42+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, elements.bitcast(convertedElTy));
3243
return success();
3344
}
3445
};
@@ -238,9 +249,9 @@ class ApplyScale32BitOpConverter : public OpRewritePattern<tosa::ApplyScaleOp> {
238249

239250
} // namespace
240251

241-
void mlir::tosa::populateTosaToArithConversionPatterns(
252+
void mlir::tosa::populateTosaToArithConversionPatterns(TypeConverter &converter,
242253
RewritePatternSet *patterns) {
243-
patterns->add<ConstOpConverter>(patterns->getContext());
254+
patterns->add<ConstOpConverter>(converter, patterns->getContext());
244255
}
245256

246257
void mlir::tosa::populateTosaRescaleToArithConversionPatterns(

mlir/lib/Conversion/TosaToArith/TosaToArithPass.cpp

+5-1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "mlir/Pass/PassManager.h"
2020
#include "mlir/Transforms/DialectConversion.h"
2121
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
22+
#include "mlir/Conversion/TosaToLinalg/TosaToLinalg.h"
2223

2324
namespace mlir {
2425
#define GEN_PASS_DEF_TOSATOARITH
@@ -34,12 +35,15 @@ struct TosaToArith : public impl::TosaToArithBase<TosaToArith> {
3435
TosaToArith(TosaToArithOptions &options) : TosaToArithBase(options) {}
3536

3637
void runOnOperation() override {
38+
TypeConverter converter;
39+
mlir::tosa::populateTosaToLinalgTypeConversion(converter);
40+
3741
RewritePatternSet patterns(&getContext());
3842
ConversionTarget target(getContext());
3943
target.addIllegalOp<tosa::ConstOp>();
4044
target.addLegalDialect<arith::ArithDialect>();
4145

42-
mlir::tosa::populateTosaToArithConversionPatterns(&patterns);
46+
mlir::tosa::populateTosaToArithConversionPatterns(converter, &patterns);
4347

4448
if (this->includeApplyRescale) {
4549
mlir::tosa::populateTosaRescaleToArithConversionPatterns(&patterns,

mlir/test/Conversion/TosaToArith/tosa-to-arith.mlir

+9-5
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@
22
// RUN: mlir-opt --split-input-file --tosa-to-arith="include-apply-rescale=false" %s -verify-diagnostics -o -| FileCheck --check-prefix="SCALE" %s
33

44
// CHECK-LABEL: func @const_test
5-
func.func @const_test() -> (tensor<i32>) {
6-
// CHECK: [[C3:%.+]] = arith.constant dense<3> : tensor<i32>
7-
%result = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
5+
func.func @const_test() -> (tensor<i32>, tensor<ui32>) {
6+
// CHECK: %[[CI32:.+]] = arith.constant dense<3> : tensor<i32>
7+
%i32 = "tosa.const"() {value = dense<3> : tensor<i32>} : () -> tensor<i32>
88

9-
// CHECK: return [[C3]]
10-
return %result : tensor<i32>
9+
// CHECK: %[[CUI32:.+]] = arith.constant dense<3> : tensor<i32>
10+
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[CUI32]] : tensor<i32> to tensor<ui32>
11+
%ui32 = "tosa.const"() {value = dense<3> : tensor<ui32>} : () -> tensor<ui32>
12+
13+
// CHECK: return %[[CI32]], %[[CAST]]
14+
return %i32, %ui32 : tensor<i32>, tensor<ui32>
1115
}
1216

1317
// -----

0 commit comments

Comments
 (0)