diff --git a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h index 9d32aa17acdcb..cae0408c3d163 100644 --- a/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h +++ b/mlir/include/mlir/Conversion/TosaToLinalg/TosaToLinalg.h @@ -52,7 +52,10 @@ void populateTosaToLinalgConversionPatterns(TypeConverter &converter, /// Populates conversion passes from TOSA dialect to Linalg named operations. void populateTosaToLinalgNamedConversionPatterns( - RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options); + TypeConverter &converter, RewritePatternSet *patterns, + const TosaToLinalgNamedOptions &options); + +void populateTosaToLinalgTypeConversion(TypeConverter &converter); } // namespace tosa } // namespace mlir diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 2bed25210d6e2..f7d82af9d8d49 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -2756,3 +2756,37 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns( // clang-format on } + +void mlir::tosa::populateTosaToLinalgTypeConversion(TypeConverter &converter) { + converter.addConversion([&](Type type) -> std::optional { + if (type.isUnsignedInteger()) { + return IntegerType::get(type.getContext(), type.getIntOrFloatBitWidth(), + IntegerType::SignednessSemantics::Signless); + } + return type; + }); + converter.addConversion([&](TensorType type) -> std::optional { + auto converted = converter.convertType(type.getElementType()); + if (!converted) + return {}; + return type.clone(converted); + }); + converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> std::optional { + if (inputs.size() != 1) + return std::nullopt; + + return builder.create(loc, resultType, inputs) + .getResult(0); + }); + converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType, + ValueRange inputs, + Location loc) -> std::optional { + if (inputs.size() != 1) + return std::nullopt; + + return builder.create(loc, resultType, inputs) + .getResult(0); + }); +} diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp index 694e99fb8f5d1..a89b186dc57ec 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp @@ -760,17 +760,23 @@ class FullyConnectedConverter } }; -class MaxPool2dConverter : public OpRewritePattern { +class MaxPool2dConverter : public OpConversionPattern { public: - using OpRewritePattern::OpRewritePattern; + using OpConversionPattern::OpConversionPattern; - LogicalResult matchAndRewrite(tosa::MaxPool2dOp op, - PatternRewriter &rewriter) const final { + LogicalResult + matchAndRewrite(tosa::MaxPool2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { Location loc = op.getLoc(); - Value input = op.getInput(); + Value input = adaptor.getInput(); ShapedType inputTy = cast(input.getType()); - ShapedType resultTy = cast(op.getType()); + bool isUnsigned = + cast(op.getType()).getElementType().isUnsignedInteger(); + ShapedType resultTy = + cast(getTypeConverter()->convertType(op.getType())); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "failed to convert type"); Type resultETy = inputTy.getElementType(); auto dynamicDimsOr = @@ -786,7 +792,10 @@ class MaxPool2dConverter : public OpRewritePattern { resultETy, APFloat::getLargest( cast(resultETy).getFloatSemantics(), true)); - if (isa(resultETy)) + else if (isUnsigned) + initialAttr = rewriter.getIntegerAttr( + resultETy, APInt::getZero(resultETy.getIntOrFloatBitWidth())); + else if (isa(resultETy)) initialAttr = rewriter.getIntegerAttr( resultETy, APInt::getSignedMinValue(resultETy.getIntOrFloatBitWidth())); @@ -823,9 +832,15 @@ class MaxPool2dConverter : public OpRewritePattern { Value fakeWindowDims = rewriter.create(loc, kernel, resultETy); - rewriter.replaceOpWithNewOp( - op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, - filledEmptyTensor, strideAttr, dilationAttr); + if (isUnsigned) { + rewriter.replaceOpWithNewOp( + op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, + filledEmptyTensor, strideAttr, dilationAttr); + } else { + rewriter.replaceOpWithNewOp( + op, ArrayRef{resultTy}, ValueRange{paddedInput, fakeWindowDims}, + filledEmptyTensor, strideAttr, dilationAttr); + } return success(); } }; @@ -1091,7 +1106,8 @@ class TransposeConverter : public OpRewritePattern { } // namespace void mlir::tosa::populateTosaToLinalgNamedConversionPatterns( - RewritePatternSet *patterns, const TosaToLinalgNamedOptions &options) { + TypeConverter &converter, RewritePatternSet *patterns, + const TosaToLinalgNamedOptions &options) { if (options.preferConv2DKernelLayoutHWCF) { patterns->add>( @@ -1105,11 +1121,13 @@ void mlir::tosa::populateTosaToLinalgNamedConversionPatterns( // clang-format off ConvConverter, DepthwiseConvConverter, - MaxPool2dConverter, AvgPool2dConverter, FullyConnectedConverter, TransposeConverter >(patterns->getContext()); + patterns->add< + MaxPool2dConverter + >(converter, patterns->getContext()); patterns->add< MatMulConverter>(patterns->getContext(), options.useMatmulForSingleBatch); // clang-format on diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp index f0a2285235aab..b01adb0a905a2 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamedPass.cpp @@ -47,6 +47,9 @@ struct TosaToLinalgNamed } void runOnOperation() override { + TypeConverter converter; + mlir::tosa::populateTosaToLinalgTypeConversion(converter); + RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); target.addLegalDialect { void runOnOperation() override { TypeConverter converter; - converter.addConversion([&](Type type) -> std::optional { - if (type.isUnsignedInteger()) { - return IntegerType::get(&getContext(), type.getIntOrFloatBitWidth(), - IntegerType::SignednessSemantics::Signless); - } - return type; - }); - converter.addConversion([&](TensorType type) -> std::optional { - auto converted = converter.convertType(type.getElementType()); - if (!converted) - return {}; - return type.clone(converted); - }); - converter.addConversion( - [&converter](FunctionType ty) -> std::optional { - SmallVector inputs; - if (failed(converter.convertTypes(ty.getInputs(), inputs))) - return std::nullopt; - - SmallVector results; - if (failed(converter.convertTypes(ty.getResults(), results))) - return std::nullopt; - - return FunctionType::get(ty.getContext(), inputs, results); - }); + mlir::tosa::populateTosaToLinalgTypeConversion(converter); RewritePatternSet patterns(&getContext()); ConversionTarget target(getContext()); diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir index e64903671e599..c5b9d96468cae 100644 --- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir +++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir @@ -199,6 +199,19 @@ func.func @max_pool_i8(%arg0: tensor<1x6x34x62xi8>) -> () { return } +// CHECK-LABEL: @max_pool_ui8 +func.func @max_pool_ui8(%arg0: tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> { + // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x6x34x62xui8> to tensor<1x6x34x62xi8> + // CHECK: arith.constant 0 + // CHECK: linalg.pooling_nhwc_max_unsigned + // CHECK-SAME: ins({{.*}} : tensor<1x6x34x62xi8>, tensor<3x3xi8>) + // CHECK-SAME: outs({{.*}} : tensor<1x4x32x62xi8>) + // CHECK-SAME: -> tensor<1x4x32x62xi8> + // CHECK: builtin.unrealized_conversion_cast {{.*}} : tensor<1x4x32x62xi8> to tensor<1x4x32x62xui8> + %0 = tosa.max_pool2d %arg0 {pad = array, kernel = array, stride = array} : (tensor<1x6x34x62xui8>) -> tensor<1x4x32x62xui8> + return %0 : tensor<1x4x32x62xui8> +} + // CHECK-LABEL: @max_pool_i16 func.func @max_pool_i16(%arg0: tensor<1x6x34x62xi16>) -> () { // CHECK: arith.constant -32768