diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp index 9ca93ab28daed..bc4ef58cbcd62 100644 --- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp +++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp @@ -91,6 +91,50 @@ createConstOpFromZpVal(Operation *op, const int64_t &zp, Type requiredAttrType, op->getLoc(), IntegerAttr::get(requiredAttrType, castedN)); } +/// Return "failure" if the given elementwise operation cannot be converted. +static LogicalResult +isSupportedElementwiseOperation(ConversionPatternRewriter &rewriter, + Operation *op, RankedTensorType resultType) { + auto elementTy = + cast(op->getOperand(0).getType()).getElementType(); + + // tosa::MulOp + if (isa(op)) { + auto shiftVal = cast(op).getShift(); + DenseElementsAttr shiftElem; + if (!matchPattern(shiftVal, m_Constant(&shiftElem))) + return rewriter.notifyMatchFailure(op, "shift value of mul not found"); + + int32_t shift = shiftElem.getValues()[0].getInt(); + if (isa(elementTy) && shift != 0) + return rewriter.notifyMatchFailure(op, + "Cannot have shift value for float"); + return success(); + } + + // tosa::NegateOp + if (isa(op)) { + auto negate = cast(op); + if (failed(negate.getInput1ZeroPoint())) + return rewriter.notifyMatchFailure( + op, "input1 zero point cannot be statically determined"); + if (failed(negate.getOutputZeroPoint())) + return rewriter.notifyMatchFailure( + op, "output zero point cannot be statically determined"); + return success(); + } + + // tosa::CastOp + if (isa(op)) { + if (!elementTy.isIntOrFloat() || + !resultType.getElementType().isIntOrFloat()) + return rewriter.notifyMatchFailure(op, "unsupported type"); + return success(); + } + + return success(); +} + static Value createLinalgBodyCalculationForElementwiseOp( Operation *op, ValueRange args, ArrayRef resultTypes, ConversionPatternRewriter &rewriter) { @@ -139,17 +183,14 @@ static Value createLinalgBodyCalculationForElementwiseOp( auto shiftVal = cast(op).getShift(); DenseElementsAttr shiftElem; if (!matchPattern(shiftVal, m_Constant(&shiftElem))) { - (void)rewriter.notifyMatchFailure(op, "shift value of mul not found"); - return nullptr; + llvm_unreachable("shift value of mul not found"); } int32_t shift = shiftElem.getValues()[0].getInt(); if (isa(elementTy)) { if (shift != 0) { - (void)rewriter.notifyMatchFailure(op, - "Cannot have shift value for float"); - return nullptr; + llvm_unreachable("Cannot have shift value for float"); } return rewriter.create(loc, resultTypes, args[0], args[1]); } @@ -196,16 +237,12 @@ static Value createLinalgBodyCalculationForElementwiseOp( FailureOr maybeInZp = negate.getInput1ZeroPoint(); if (failed(maybeInZp)) { - (void)rewriter.notifyMatchFailure( - op, "input1 zero point cannot be statically determined"); - return nullptr; + llvm_unreachable("input1 zero point cannot be statically determined"); } FailureOr maybeOutZp = negate.getOutputZeroPoint(); if (failed(maybeOutZp)) { - (void)rewriter.notifyMatchFailure( - op, "output zero point cannot be statically determined"); - return nullptr; + llvm_unreachable("output zero point cannot be statically determined"); } int64_t inZp = *maybeInZp; @@ -548,10 +585,8 @@ static Value createLinalgBodyCalculationForElementwiseOp( if (isa(op)) { Type srcTy = elementTy; Type dstTy = resultTypes.front(); - if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) { - (void)rewriter.notifyMatchFailure(op, "unsupported type"); - return nullptr; - } + if (!srcTy.isIntOrFloat() || !dstTy.isIntOrFloat()) + llvm_unreachable("unsupported type"); bool bitExtend = srcTy.getIntOrFloatBitWidth() < dstTy.getIntOrFloatBitWidth(); @@ -706,8 +741,8 @@ static Value createLinalgBodyCalculationForElementwiseOp( } } - (void)rewriter.notifyMatchFailure( - op, "unhandled op for linalg body calculation for elementwise op"); + llvm_unreachable( + "unhandled op for linalg body calculation for elementwise op"); return nullptr; } @@ -930,17 +965,11 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc, }); } -static LogicalResult -emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, - Operation *operation, ValueRange operands, - ArrayRef targetShape, - const TypeConverter &converter) { +static LogicalResult emitElementwiseComputation( + ConversionPatternRewriter &rewriter, Location loc, Operation *operation, + ValueRange operands, ArrayRef targetShape, + const TypeConverter &converter, RankedTensorType resultType) { // Generate output tensor - auto resultType = cast_or_null( - converter.convertType(operation->getResultTypes().front())); - if (!resultType) { - return rewriter.notifyMatchFailure(operation, "failed to convert type"); - } Value outputTensor = rewriter.create( loc, targetShape, resultType.getElementType()); @@ -967,7 +996,6 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, affineMaps.push_back(rewriter.getMultiDimIdentityMap(rank)); // Emit 'linalg.generic' op - bool encounteredError = false; auto linalgOp = rewriter.create( loc, outputTensor.getType(), operands, outputTensor, affineMaps, getNParallelLoopsAttrs(rank), @@ -975,15 +1003,10 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc, Value opResult = createLinalgBodyCalculationForElementwiseOp( operation, blockArgs.take_front(operation->getNumOperands()), {resultType.getElementType()}, rewriter); - if (!opResult) { - encounteredError = true; - return; - } + assert(opResult && + "unable to create linalg.generic body for elementwise op"); opBuilder.create(loc, opResult); }); - if (encounteredError) - return rewriter.notifyMatchFailure( - operation, "unable to create linalg.generic body for elementwise op"); // Cast 'linalg.generic' result into original result type if needed auto castResult = rewriter.createOrFold( @@ -1008,13 +1031,20 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, ConversionPatternRewriter &rewriter, const TypeConverter &converter) { - // Collect op properties + // Check if operation is supported. assert(operation->getNumResults() == 1 && "elementwise op expects 1 result"); assert(operation->getNumOperands() >= 1 && "elementwise op expects at least 1 operand"); if (!operandsAndResultsRanked(operation)) return rewriter.notifyMatchFailure(operation, "Unranked tensors not supported"); + auto resultType = cast_or_null( + converter.convertType(operation->getResultTypes().front())); + if (!resultType) { + return rewriter.notifyMatchFailure(operation, "failed to convert type"); + } + if (failed(isSupportedElementwiseOperation(rewriter, operation, resultType))) + return failure(); // Lower operation IndexPool indexPool; @@ -1026,7 +1056,7 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands, broadcastDynamicDimensions(rewriter, loc, indexPool, operandsToBroadcast, targetShape, masterOperands); return emitElementwiseComputation(rewriter, loc, operation, broadcastOperands, - targetShape, converter); + targetShape, converter, resultType); } // Returns the constant initial value for a given reduction operation. The @@ -1126,7 +1156,7 @@ static Value createLinalgBodyCalculationForReduceOp(Operation *op, if (isa(op) && elementTy.isInteger(1)) return rewriter.create(loc, args); - return {}; + llvm_unreachable("unhandled reduction op"); } // Performs the match and rewrite for reduction operations. This includes @@ -1142,6 +1172,10 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, return rewriter.notifyMatchFailure(op, "unranked tensors not supported"); auto elementTy = resultTy.getElementType(); + auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); + if (!fillValueAttr) + return rewriter.notifyMatchFailure( + op, "No initial value found for reduction operation"); Value input = op->getOperand(0); SmallVector reduceShape; @@ -1164,11 +1198,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, dynDims) .getResult(); - auto fillValueAttr = createInitialValueForReduceOp(op, elementTy, rewriter); - if (!fillValueAttr) - return rewriter.notifyMatchFailure( - op, "No initial value found for reduction operation"); - auto fillValue = rewriter.create(loc, fillValueAttr); auto filledTensor = rewriter .create(loc, ValueRange{fillValue}, @@ -1212,7 +1241,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, } } - bool didEncounterError = false; linalg::LinalgOp linalgOp = rewriter.create( loc, inputs, outputs, axis, [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange blockArgs) { @@ -1220,8 +1248,7 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, blockArgs[0], isNanIgnoreMode ? blockArgs[2] : blockArgs[1]}; auto result = createLinalgBodyCalculationForReduceOp( op, binaryArgs, elementTy, rewriter); - if (result) - didEncounterError = true; + assert(result && "could not create reduction body"); SmallVector resultsToYield; if (isNanIgnoreMode) { @@ -1247,10 +1274,6 @@ static LogicalResult reduceMatchAndRewriteHelper(OpTy op, uint64_t axis, nestedBuilder.create(loc, resultsToYield); }); - if (!didEncounterError) - return rewriter.notifyMatchFailure( - op, "unable to create linalg.generic body for reduce op"); - if (isNanIgnoreMode) { // Materialize a check to see whether we encountered any non-NaN values, if // we didn't we need to select a tensor of NaNs since the result will just @@ -1358,13 +1381,6 @@ class RescaleConverter : public OpRewritePattern { if (!isa(inputTy.getElementType())) return rewriter.notifyMatchFailure(op, "only support integer type"); - SmallVector dynDims; - for (int i = 0; i < outputTy.getRank(); i++) { - if (outputTy.isDynamicDim(i)) { - dynDims.push_back(rewriter.create(loc, input, i)); - } - } - // The shift and multiplier values. DenseElementsAttr shiftElems; if (!matchPattern(op.getShift(), m_Constant(&shiftElems))) @@ -1376,6 +1392,21 @@ class RescaleConverter : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "tosa.rescale requires constant multiplier input values"); + if (failed(op.getInputZeroPoint())) + return rewriter.notifyMatchFailure( + op, "input zero point cannot be statically determined"); + + if (failed(op.getOutputZeroPoint())) + return rewriter.notifyMatchFailure( + op, "output zero point cannot be statically determined"); + + SmallVector dynDims; + for (int i = 0; i < outputTy.getRank(); i++) { + if (outputTy.isDynamicDim(i)) { + dynDims.push_back(rewriter.create(loc, input, i)); + } + } + llvm::SmallVector shiftValues = llvm::to_vector(shiftElems.getValues()); // explicit cast is required here @@ -1473,23 +1504,10 @@ class RescaleConverter : public OpRewritePattern { int32_t inBitwidth = valueTy.getIntOrFloatBitWidth() > 32 ? 48 : 32; FailureOr maybeIZp = op.getInputZeroPoint(); - if (failed(maybeIZp)) { - (void)rewriter.notifyMatchFailure( - op, "input zero point cannot be statically determined"); - return; - } - auto inputZp = createConstOpFromZpVal( op, *maybeIZp, nestedBuilder.getIntegerType(inBitwidth), nestedBuilder); - FailureOr maybeOZp = op.getOutputZeroPoint(); - if (failed(maybeOZp)) { - (void)rewriter.notifyMatchFailure( - op, "output zero point cannot be statically determined"); - return; - }; - auto outputZp = createConstOpFromZpVal( op, *maybeOZp, nestedBuilder.getI32Type(), nestedBuilder); @@ -1783,6 +1801,15 @@ class GenericResizeConverter : public OpRewritePattern { return rewriter.notifyMatchFailure( op, "tosa.resize mode should be NEAREST_NEIGHBOR or BILINEAR"); + SmallVector scale, offset, border; + if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) || + !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) || + !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) { + return rewriter.notifyMatchFailure( + op, "tosa.resize scale/offset/border should have compile time " + "constant values."); + } + SmallVector affineMaps = { rewriter.getMultiDimIdentityMap(resultTy.getRank())}; auto emptyTensor = b.create(resultTy.getShape(), resultETy, @@ -1810,15 +1837,6 @@ class GenericResizeConverter : public OpRewritePattern { Value inY = b.create(b.getI32Type(), y); Value inX = b.create(b.getI32Type(), x); - SmallVector scale, offset, border; - if (!tosa::getConstShapeValues(op.getScale().getDefiningOp(), scale) || - !tosa::getConstShapeValues(op.getOffset().getDefiningOp(), offset) || - !tosa::getConstShapeValues(op.getBorder().getDefiningOp(), border)) { - return rewriter.notifyMatchFailure( - op, "tosa.resize scale/offset/border should have compile time " - "constant values."); - } - Value yScaleN, yScaleD, xScaleN, xScaleD; yScaleN = b.create(b.getI32IntegerAttr(scale[0])); yScaleD = b.create(b.getI32IntegerAttr(scale[1])); @@ -2204,6 +2222,9 @@ class ArgMaxConverter : public OpRewritePattern { auto inputTy = cast(input.getType()); auto resultTy = cast(argmaxOp.getOutput().getType()); auto inElementTy = inputTy.getElementType(); + if (!isa(inElementTy)) + return rewriter.notifyMatchFailure( + argmaxOp, "unsupported tosa.argmax element type"); auto outElementTy = resultTy.getElementType(); int axis = argmaxOp.getAxis(); auto resultMaxTy = RankedTensorType::get(resultTy.getShape(), inElementTy); @@ -2213,6 +2234,12 @@ class ArgMaxConverter : public OpRewritePattern { argmaxOp, "tosa.arg_max to linalg.* requires integer-like result type"); + auto fillValueMaxAttr = + createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); + if (!fillValueMaxAttr) + return rewriter.notifyMatchFailure( + argmaxOp, "unsupported tosa.argmax element type"); + SmallVector dynDims; for (int i = 0; i < inputTy.getRank(); i++) { if (inputTy.isDynamicDim(i) && i != axis) { @@ -2238,12 +2265,6 @@ class ArgMaxConverter : public OpRewritePattern { .create(loc, resultTy.getShape(), inElementTy, dynDims) .getResult(); - auto fillValueMaxAttr = - createInitialValueForReduceOp(argmaxOp, inElementTy, rewriter); - - if (!fillValueMaxAttr) - return rewriter.notifyMatchFailure( - argmaxOp, "unsupported tosa.argmax element type"); auto fillValueMax = rewriter.create(loc, fillValueMaxAttr); @@ -2267,7 +2288,6 @@ class ArgMaxConverter : public OpRewritePattern { dstExprs.push_back(mlir::getAffineDimExpr(i, rewriter.getContext())); } - bool didEncounterError = false; auto maps = AffineMap::inferFromExprList({srcExprs, dstExprs, dstExprs}, rewriter.getContext()); auto linalgOp = rewriter.create( @@ -2305,8 +2325,7 @@ class ArgMaxConverter : public OpRewritePattern { predicate = rewriter.create( nestedLoc, arith::CmpIPredicate::sgt, newValue, oldValue); } else { - didEncounterError = true; - return; + llvm_unreachable("unsupported tosa.argmax element type"); } auto resultMax = rewriter.create( @@ -2317,10 +2336,6 @@ class ArgMaxConverter : public OpRewritePattern { nestedLoc, ValueRange({resultIndex, resultMax})); }); - if (didEncounterError) - return rewriter.notifyMatchFailure( - argmaxOp, "unsupported tosa.argmax element type"); - rewriter.replaceOp(argmaxOp, linalgOp.getResult(0)); return success(); } @@ -2416,6 +2431,15 @@ class TableConverter : public OpRewritePattern { auto tableElementTy = tableTy.getElementType(); auto resultElementTy = resultTy.getElementType(); + bool isI8_8_8 = inputElementTy.isInteger(8) && + tableElementTy.isInteger(8) && resultElementTy.isInteger(8); + bool isI16_16_32 = inputElementTy.isInteger(16) && + tableElementTy.isInteger(16) && + resultElementTy.isInteger(32); + if (!isI8_8_8 && !isI16_16_32) + return rewriter.notifyMatchFailure( + op, "unable to create body for tosa.table op"); + SmallVector dynDims; for (int i = 0; i < resultTy.getRank(); ++i) { if (inputTy.isDynamicDim(i)) { @@ -2446,8 +2470,7 @@ class TableConverter : public OpRewritePattern { auto inputValue = block->getArgument(0); rewriter.setInsertionPointToStart(block); - if (inputElementTy.isInteger(8) && tableElementTy.isInteger(8) && - resultElementTy.isInteger(8)) { + if (isI8_8_8) { Value index = rewriter.create( loc, rewriter.getIndexType(), inputValue); Value offset = rewriter.create(loc, 128); @@ -2459,8 +2482,7 @@ class TableConverter : public OpRewritePattern { return success(); } - if (inputElementTy.isInteger(16) && tableElementTy.isInteger(16) && - resultElementTy.isInteger(32)) { + if (isI16_16_32) { Value extend = rewriter.create( loc, rewriter.getI32Type(), inputValue); @@ -2516,8 +2538,7 @@ class TableConverter : public OpRewritePattern { } } - return rewriter.notifyMatchFailure( - op, "unable to create body for tosa.table op"); + llvm_unreachable("unable to create body for tosa.table op"); } };