diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..4f304b39a0528 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of Arm FEAT_I8MM instructions while lowering "
            "the vector dialect.">,
+    Option<"armBF16", "enable-arm-bf16",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_BF16 instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
index bcaca7da967fa..35747126d3db1 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
@@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp
    : Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_i8mm",
         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that vector.contract operations should be lowered to
-    finer-grained vector primitives from the ArmNeon dialect.
+    Indicates that vector contract operations should be lowered to
+    ArmNeon dialect operations mapping to instructions from FEAT_I8MM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyArmNeonContractionToBFMMLAPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla",
+         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector contract operations should be lowered to
+    ArmNeon dialect operations mapping to instructions from FEAT_BF16.
}]; let assemblyFormat = "attr-dict"; diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h index 2f0f634a96770..08065a3b25266 100644 --- a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h +++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h @@ -13,8 +13,8 @@ namespace mlir { class RewritePatternSet; namespace arm_neon { -void populateLowerContractionToNeonI8MMPatternPatterns( - RewritePatternSet &patterns); +void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns); +void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns); } // namespace arm_neon } // namespace mlir diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp index 549d0210af7ad..1045824c437ab 100644 --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp @@ -84,10 +84,12 @@ void ConvertVectorToLLVMPass::runOnOperation() { populateVectorGatherLoweringPatterns(patterns); if (armI8MM) { if (armNeon) - arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns); + arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); if (armSVE) populateLowerContractionToSVEI8MMPatternPatterns(patterns); } + if (armBF16 && armNeon) + arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp index d07e6a52d8b5f..d069bde6d9979 100644 --- a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp +++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp @@ -20,7 +20,12 @@ using namespace mlir; void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns( RewritePatternSet &patterns) { - arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns); + arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns); +} + +void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns( + RewritePatternSet &patterns) { + arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns); } //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt index 06bafde451cbb..368dacac7b835 100644 --- a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt @@ -1,5 +1,5 @@ add_mlir_dialect_library(MLIRArmNeonTransforms - LowerContractionToNeonI8MMPattern.cpp + LowerContractToNeonPatterns.cpp DEPENDS MLIRArmNeonIncGen diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp new file mode 100644 index 0000000000000..06746daa8075b --- /dev/null +++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp @@ -0,0 +1,499 @@ +//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the Neon FEAT_I8MM and FEAT_BF16 extensions.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for SVE. See:
+//   https://github.com/llvm/llvm-project/issues/145559
+//   LowerContractionToSVEI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-neon"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+/// Get the operand of a `vector.contract`. This function is intended to
+/// abstract away from the particular way a value is extended before feeding it
+/// into the `vector.contract` - via zero-extend or an explicit or implicit
+/// sign-extend (for implicit sign-extension see `vector.contract`
+/// documentation).
+///
+/// The template parameter `Op` indicates the extension operation (explicit or
+/// implicit) for which we are checking.
+///
+// Return success only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero-extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto eltTy = cast<VectorType>(v.getType()).getElementType();
+      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy)
+    return {};
+  auto inEltTy = inTy.getElementType();
+  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+    return {};
+
+  return inOp;
+}
+
+/// Helper function to extend a vector with elements iN, N < 8 to
+/// a vector of i8. Do sign extension if the parameter `signExt` is true,
+/// zero extension otherwise.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+                           bool signExt, PatternRewriter &rewriter) {
+  Type targetTy = srcTy.clone(rewriter.getI8Type());
+  return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+                 : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
+
+class VectorContractRewriter {
+protected:
+  // Designate the operation (resp. instruction) used to do sub-tile matrix
+  // multiplications.
+  enum class MMLA {
+    Nop,
+    SignedInt,   // smmla
+    UnsignedInt, // ummla
+    MixedInt,    // usmmla
+    Bfloat       // bfmmla
+  };
+
+  // Lower-level operation to be emitted.
+  MMLA mmlaOp = MMLA::Nop;
+
+  // Indicate if the operands for the ArmNeon dialect operation need to be
+  // swapped. Currently this is needed in order to emulate a "summla"
+  // operation.
+  bool swapOperands = false;
+
+  // The operand tiles. These are not necessarily the operands of
+  // `vector.contract`, for example they could be operands to `arith.extsi`
+  // that is in turn fed into `vector.contract`.
+  Value lhs;
+  Value rhs;
+  Value acc;
+
+  // The dimensions logically corresponding to matrix multiplication of
+  // MxK * KxN -> MxN. The operands and the result do not necessarily have
+  // these shapes, for example RHS could be NxK with a transposing indexing
+  // map.
+  int64_t dimM = 0;
+  int64_t dimN = 0;
+  int64_t dimK = 0;
+
+  // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
+  SmallVector<int64_t> iterationBounds;
+
+  // Sub-tile shape. The algorithm handles operand shapes that are multiples
+  // of this shape.
+  SmallVector<int64_t> subTileShape;
+
+  // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+                   Value lhs, Value rhs) {
+
+    if (swapOperands)
+      std::swap(lhs, rhs);
+    switch (mmlaOp) {
+    case MMLA::SignedInt:
+      return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::UnsignedInt:
+      return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::MixedInt:
+      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+                                                       lhs, rhs);
+    case MMLA::Bfloat:
+      return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
+                                                 rhs);
+    case MMLA::Nop:
+      llvm_unreachable("Uninitialized operation type");
+    }
+  }
+
+  // Check common preconditions for applying the patterns and initialize
+  // logical dimensions.
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    // Check iterator types for matrix multiplication.
+    SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+    if (!((itTypes.size() == 3 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::parallel &&
+            itTypes[2] == vector::IteratorType::reduction)) ||
+          (itTypes.size() == 2 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::reduction))))
+      return rewriter.notifyMatchFailure(
+          op, "iterator types do not correspond to matrix multiplication");
+
+    // Avoid 0-D vectors and 1-D rhs:
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+        rhsType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
+    // This codegen does not work for scalable vectors. Return failure so this
+    // pattern is not accidentally chosen over patterns that lower to ArmSVE.
+    if (lhsType.isScalable() || rhsType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "Not applicable to scalable vectors");
+
+    // Initialize dimensions and check for a matching K dimension.
+    dimM = lhsType.getDimSize(0);
+    dimN = rhsType.getDimSize(0);
+    dimK = rhsType.getDimSize(1);
+
+    int64_t lhsDimK;
+    if (lhsType.getRank() == 1) {
+      dimM = 1;
+      lhsDimK = lhsType.getDimSize(0);
+    } else {
+      lhsDimK = lhsType.getDimSize(1);
+    }
+
+    if (lhsDimK != dimK)
+      return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
+
+    return success();
+  }
+
+public:
+  void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
+    // Create some convenience types.
+    auto inputElementType = cast<VectorType>(lhs.getType()).getElementType();
+    auto accElementType = cast<VectorType>(acc.getType()).getElementType();
+    auto inputExpandedType =
+        VectorType::get({2, subTileShape.back()}, inputElementType);
+    auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+    // One-dimensional representation of logical sub-tiles as required by the
+    // ArmNeon ops.
+    auto collapsedInputType =
+        VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+    auto collapsedOutputType =
+        VectorType::get(outputExpandedType.getNumElements(), accElementType);
+
+    // Get indexing maps for a more concise/convenient access.
+    auto indexingMaps = op.getIndexingMapsArray();
+    AffineMap &lhsPermutationMap = indexingMaps[0];
+    AffineMap &rhsPermutationMap = indexingMaps[1];
+    AffineMap &accPermutationMap = indexingMaps[2];
+
+    Location loc = op.getLoc();
+
+    // Initial accumulator for the final result. This is the un-tiled result if
+    // tiling is done.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+    SmallVector<int64_t> loopOrder = {0, 1};
+    if (iterationBounds.size() == 3)
+      loopOrder.push_back(2);
+
+    // Keep track of the previous accumulator when tiling over K.
+    Value kAcc;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
+      // Helper to compute the new shape of each operand and extract the slice.
+      auto extractOperand = [&](Value operand, AffineMap permutationMap,
+                                ArrayRef<int64_t> operandOffsets) {
+        SmallVector<int64_t> operandShape = applyPermutationMap(
+            permutationMap, ArrayRef<int64_t>(subTileShape));
+        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand, operandOffsets, operandShape, operandStrides);
+      };
+
+      // Extract tiled lhs, rhs, and acc
+      SmallVector<int64_t> lhsOffsets =
+          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+      Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
+      SmallVector<int64_t> rhsOffsets =
+          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+      Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
+      SmallVector<int64_t> accOffsets =
+          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
+      Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);
+
+      // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
+      // rows along dimM. Expand their shapes to match the ArmNeon op.
+      if (dimM == 1) {
+        auto expandRowVector = [&](Value tiledOperand,
+                                   VectorType expandedTypeType) {
+          auto emptyOperand = rewriter.create<arith::ConstantOp>(
+              loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
+          SmallVector<int64_t> offsets(
+              cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
+          SmallVector<int64_t> strides(
+              cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
+          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
+              loc, tiledOperand, emptyOperand, offsets, strides);
+        };
+        tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
+        tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
+      }
+
+      // Transpose ACC if doing signed by unsigned multiplication, because we're
+      // using the instruction for unsigned by signed multiplication with
+      // reversed operands.
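+      // (This relies on the transpose identity (LHS * RHS)^T = RHS^T * LHS^T:
+      // the accumulator sub-tile is transposed here, the operands are swapped
+      // inside `createMMLA`, and the result sub-tile is transposed back
+      // further below.)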
+      if (swapOperands)
+        tiledAcc = rewriter.create<vector::TransposeOp>(
+            loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
+
+      // Collapse tiled operands to 1D vectors required by the ArmNeon ops
+      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
+      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+
+      bool initialKAcc = offsets.back() == 0;
+      Value collapsedRes;
+      if (!initialKAcc) {
+        collapsedRes = kAcc;
+      } else {
+        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+      }
+
+      // Insert contract op
+      kAcc =
+          createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
+
+      // Reshape output back to 2D
+      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
+          kAcc.getLoc(), tiledAcc.getType(), kAcc);
+
+      // Because of the reversed operands the result is obtained transposed.
+      // Transpose it back.
+      if (swapOperands)
+        tiledRes = rewriter.create<vector::TransposeOp>(
+            loc, tiledRes, ArrayRef<int64_t>({1, 0}));
+
+      // With vecmat, only one row of tiled ACC can be inserted into the final
+      // result
+      if (dimM == 1)
+        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+
+      // Insert the tiled result back into the non-tiled result of the
+      // contract op.
+      SmallVector<int64_t> strides(
+          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tiledRes, result, accOffsets, strides);
+    }
+
+    rewriter.replaceOp(op, result);
+  }
+};
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+      return failure();
+
+    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs
+    // for tiling.
+    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
+      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+    // Check inputs are sign-/zero-extensions from iN (N <= 8) to i32. Get the
+    // values before the extension. All four signed/unsigned combinations for
+    // input operands are supported, but they are lowered to different
+    // operations. Determine which is the appropriate operation to lower to.
+    mmlaOp = MMLA::SignedInt;
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::UnsignedInt;
+      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+    }
+    if (!maybeLhs)
+      return rewriter.notifyMatchFailure(
+          op, "LHS is not a sign- or zero-extended iN, N <= 8");
+
+    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
+    if (maybeRhs) {
+      if (mmlaOp == MMLA::UnsignedInt)
+        mmlaOp = MMLA::MixedInt;
+    } else {
+      if (mmlaOp == MMLA::SignedInt) {
+        mmlaOp = MMLA::MixedInt;
+        swapOperands = true;
+      }
+      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
+    }
+
+    if (!maybeRhs)
+      return rewriter.notifyMatchFailure(
+          op, "RHS is not a sign- or zero-extended iN, N <= 8");
+
+    lhs = *maybeLhs;
+    rhs = *maybeRhs;
+    acc = op.getAcc();
+
+    // Extend inputs from iN, N < 8 to i8.
+    Location loc = op.getLoc();
+    auto lhsExtInType = cast<VectorType>(lhs.getType());
+    if (lhsExtInType.getElementTypeBitWidth() < 8)
+      lhs = extendSmallIntVector(loc, lhsExtInType, lhs,
+                                 /* signExt */
+                                 (mmlaOp == MMLA::SignedInt ||
+                                  (mmlaOp == MMLA::MixedInt && !swapOperands)),
+                                 rewriter);
+
+    auto rhsExtInType = cast<VectorType>(rhs.getType());
+    if (rhsExtInType.getElementTypeBitWidth() < 8)
+      rhs = extendSmallIntVector(loc, rhsExtInType, rhs,
+                                 /* signExt */
+                                 (mmlaOp == MMLA::SignedInt ||
+                                  (mmlaOp == MMLA::MixedInt && swapOperands)),
+                                 rewriter);
+
+    // Initialize parameters for unrolling.
+    iterationBounds = *op.getShapeForUnroll();
+    if (iterationBounds.size() == 3)
+      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 8});
+    else
+      subTileShape = SmallVector<int64_t>({2, 8});
+
+    return success();
+  }
+};
+
+class VectorContractRewriterBFMMLA : public VectorContractRewriter {
+public:
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+
+    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+      return failure();
+
+    // Unrolling patterns can handle any [2, 2, 4] shaped multiple of inputs
+    // for tiling.
+    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 4 != 0)
+      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+    // Check the output is a vector of Float32 elements.
+    auto outTy = dyn_cast<VectorType>(op.getResultType());
+    if (!outTy || outTy.getElementType() != rewriter.getF32Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "output type is not a vector of f32");
+
+    // Check the inputs are vectors of BFloat16 elements.
+    if (op.getLhsType().getElementType() != rewriter.getBF16Type())
+      return rewriter.notifyMatchFailure(op,
+                                         "input type is not a vector of bf16");
+
+    mmlaOp = MMLA::Bfloat;
+    swapOperands = false;
+    lhs = op.getLhs();
+    rhs = op.getRhs();
+    acc = op.getAcc();
+
+    // Initialize parameters for unrolling.
+    iterationBounds = *op.getShapeForUnroll();
+    if (iterationBounds.size() == 3)
+      subTileShape = SmallVector<int64_t>({dimM == 1 ? 1 : 2, 2, 4});
+    else
+      subTileShape = SmallVector<int64_t>({2, 4});
+
+    return success();
+  }
+};
+
+/// Lowering from a vector::ContractionOp to the Arm Neon smmla intrinsic. This
+/// will tile any vector.contract into multiple smmla instructions with
+/// unrolling so long as [2,2,8] is a divisor of its shape. It can also process
+/// vecmats with dimM = 1 (either explicitly or inferred if LHS has only dimK).
+/// If no unrolling is necessary, a single smmla instruction is emitted.
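+///
+/// As an illustration (a sketch only, with operands already sub-tile sized so
+/// that no unrolling is needed), a contraction such as
+///
+///   %0 = arith.extsi %lhs : vector<2x8xi8> to vector<2x8xi32>
+///   %1 = arith.extsi %rhs : vector<2x8xi8> to vector<2x8xi32>
+///   %2 = vector.contract ... %0, %1, %acc
+///          : vector<2x8xi32>, vector<2x8xi32> into vector<2x2xi32>
+///
+/// is lowered, roughly, to
+///
+///   %3 = vector.shape_cast %lhs : vector<2x8xi8> to vector<16xi8>
+///   %4 = vector.shape_cast %rhs : vector<2x8xi8> to vector<16xi8>
+///   %5 = vector.shape_cast %acc : vector<2x2xi32> to vector<4xi32>
+///   %6 = arm_neon.intr.smmla %5, %3, %4 : vector<16xi8> to vector<4xi32>
+///   %7 = vector.shape_cast %6 : vector<4xi32> to vector<2x2xi32>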
+class LowerContractionToNeonI8MMPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorContractRewriterI8MM vcr;
+    if (failed(vcr.matchAndInit(op, rewriter)))
+      return failure();
+    vcr.rewrite(op, rewriter);
+
+    return success();
+  }
+};
+
+class LowerContractionToNeonBFMMLAPattern
+    : public OpRewritePattern<vector::ContractionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(vector::ContractionOp op,
+                                PatternRewriter &rewriter) const override {
+
+    VectorContractRewriterBFMMLA vcr;
+    if (failed(vcr.matchAndInit(op, rewriter)))
+      return failure();
+    vcr.rewrite(op, rewriter);
+
+    return success();
+  }
+};
+
+} // namespace
+
+void mlir::arm_neon::populateLowerContractionToNeonI8MMPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
+}
+
+void mlir::arm_neon::populateLowerContractionToNeonBFMMLAPatterns(
+    RewritePatternSet &patterns) {
+  MLIRContext *context = patterns.getContext();
+  patterns.add<LowerContractionToNeonBFMMLAPattern>(context, /*benefit=*/2);
+}
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
deleted file mode 100644
index 7180884c77e98..0000000000000
--- a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp
+++ /dev/null
@@ -1,364 +0,0 @@
-//===- LowerContractionToNeonI8MMPattern.cpp - Contract to I8MM -*- C++ -*-===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-//
-// This file implements lowering patterns from vector.contract to operations
-// that map to instructions from the Neon FEAT_I8MM extension.
-//
-// TODO: There may be opportunities to unify this with a similar pattern
-// for SVE. See:
-//   https://github.com/llvm/llvm-project/issues/145559
-//   LowerContractionToSVEI8MMPattern.cpp
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/Arith/IR/Arith.h"
-#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
-#include "mlir/Dialect/ArmNeon/Transforms.h"
-#include "mlir/Dialect/Func/IR/FuncOps.h"
-#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
-#include "mlir/Dialect/Utils/IndexingUtils.h"
-#include "mlir/Dialect/Vector/IR/VectorOps.h"
-#include "mlir/IR/AffineMap.h"
-#include "mlir/IR/PatternMatch.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
-
-#define DEBUG_TYPE "lower-contract-to-arm-neon"
-
-using namespace mlir;
-using namespace mlir::arm_neon;
-
-namespace {
-
-/// Return the shaped type with new element type.
-static Type matchContainerType(Type element, Type container) {
-  if (auto shapedTy = dyn_cast<ShapedType>(container)) {
-    return shapedTy.clone(element);
-  }
-  return element;
-}
-
-// Get the operand of a `vector.contract`. This function is intended to
-// abstract away from the particular way a value is extended before feeding it
-// into the `vector.contract` - via zero-extend or an explicit or implicit
-// sign-extend (for implicit sign-extension see `vector.contract`
-// documentation).
-//
-// The template parameter `Op` indicates the extension operation (explicit or
-// implicit) for which we are checking.
-//
-// Return success only for extensions from `iN` (N <= 8) to `i32`.
-template <typename Op>
-std::optional<Value> getExtOperand(Value v) {
-
-  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
-                "Must be instantiated with either sign- or zero-extension op");
-
-  // If the operand is not defined by an explicit extend operation of the
-  // accepted operation type allow for an implicit sign-extension.
-  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
-  if (!extOp) {
-    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
-      auto eltTy = cast<VectorType>(v.getType()).getElementType();
-      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
-        return {};
-      return v;
-    }
-    return {};
-  }
-
-  // If the operand is defined by an explicit extend operation of the accepted
-  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
-  auto inOp = extOp.getIn();
-  auto inTy = dyn_cast<VectorType>(inOp.getType());
-  if (!inTy)
-    return {};
-  auto inEltTy = inTy.getElementType();
-  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
-    return {};
-
-  auto outTy = dyn_cast<VectorType>(extOp.getType());
-  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
-    return {};
-
-  return inOp;
-}
-
-// Designate the operation (resp. instruction) used to do sub-tile matrix
-// multiplications.
-enum class MMLA {
-  Signed,      // smmla
-  Unsigned,    // ummla
-  Mixed,       // usmmla
-  MixedSwapped // usmmla with LHS and RHS swapped
-};
-
-// Create the matrix multiply and accumulate operation according to `op`.
-Value createMMLA(PatternRewriter &rewriter, MMLA op, Location loc,
-                 mlir::Type accType, Value acc, Value lhs, Value rhs) {
-  switch (op) {
-  case MMLA::Signed:
-    return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, accType, acc, lhs,
-                                                    rhs);
-  case MMLA::Unsigned:
-    return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, accType, acc, lhs,
-                                                    rhs);
-  case MMLA::Mixed:
-    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, lhs,
-                                                     rhs);
-  case MMLA::MixedSwapped:
-    // The accumulator comes transposed and the result will be transposed
-    // later, so all we have to do here is swap the operands.
-    return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, accType, acc, rhs,
-                                                     lhs);
-  }
-}
-
-/// Lowering from a vector::ContractionOp to the Arm Neon smmla intrinsic. This
-/// will tile any vector.contract into multiple smmla instructions with
-/// unrolling so long as [2,2,8] is a divisor of its shape. It can also process
-/// vecmats with dimM = 1 (either explicitly or inferred if LHS has only dimK).
-/// If no unrolling is necessary, a single smmla instruction is emitted.
-class LowerContractionToNeonI8MMPattern
-    : public OpRewritePattern<vector::ContractionOp> {
-public:
-  using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(vector::ContractionOp op,
-                                PatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    // Infer tile sizes from operands. For vecmat, LHS may only have 1 dim.
-    // Note: RHS is not transposed.
-    mlir::VectorType lhsType = op.getLhsType();
-    mlir::VectorType rhsType = op.getRhsType();
-    // Avoid 0-D vectors and 1-D rhs:
-    if (!lhsType.hasRank() || !rhsType.hasRank() || rhsType.getRank() < 2)
-      return failure();
-    // This codegen does not work for scalable vectors. Return failure so this
-    // pattern is not accidentally chosen over patterns that lower to ArmSVE.
-    if (lhsType.isScalable() || rhsType.isScalable())
-      return failure();
-    auto dimM = lhsType.getRank() == 1 ? 1 : lhsType.getDimSize(0);
-    auto dimN = rhsType.getDimSize(0);
-    auto dimK = rhsType.getDimSize(1);
-    bool isVecmat = dimM == 1 ? true : false;
-    if (lhsType.getDimSize(lhsType.getRank() - 1) !=
-        rhsType.getDimSize(rhsType.getRank() - 1)) {
-      return failure(); // dimK mismatch
-    }
-    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs
-    // for tiling.
-    if ((dimM % 2 != 0 && !isVecmat) || dimN % 2 != 0 || dimK % 8 != 0) {
-      return failure();
-    }
-
-    // Check iterator types for contract. All iterators except inner-most
-    // dimension must be parallel.
-    auto iteratorTypes = op.getIteratorTypesArray();
-    if (iteratorTypes.size() > 3 || iteratorTypes[iteratorTypes.size() - 1] !=
-                                        vector::IteratorType::reduction) {
-      return failure();
-    }
-    if (llvm::any_of(ArrayRef(iteratorTypes).drop_back(1),
-                     [](vector::IteratorType iteratorType) {
-                       return iteratorType != vector::IteratorType::parallel;
-                     })) {
-      return failure();
-    }
-
-    // Check inputs are sign-/zero-extensions from iN (N <= 8) to i32. Get the
-    // values before the extension. All four signed/unsigned combinations for
-    // input operands are supported, but they are lowered to different
-    // operations. Determine which is the appropriate operation to lower to.
-    MMLA mmlaOp = MMLA::Signed;
-    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
-    if (!maybeLhs) {
-      mmlaOp = MMLA::Unsigned;
-      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
-    }
-    if (!maybeLhs)
-      return failure();
-
-    auto maybeRhs = getExtOperand<arith::ExtSIOp>(op.getRhs());
-    if (maybeRhs) {
-      if (mmlaOp == MMLA::Unsigned)
-        mmlaOp = MMLA::Mixed;
-    } else {
-      if (mmlaOp == MMLA::Signed)
-        mmlaOp = MMLA::MixedSwapped;
-      maybeRhs = getExtOperand<arith::ExtUIOp>(op.getRhs());
-    }
-    if (!maybeRhs)
-      return failure();
-
-    Value origLhs = *maybeLhs;
-    Value origRhs = *maybeRhs;
-
-    // Match any iX to i32 for X<8 then turn into an i8 output. Feed into
-    // following neon instruction. Check inputs for extsi are <=i8
-    Value extLhs;
-    Value extRhs;
-    if (auto lhsExtInType = dyn_cast<mlir::VectorType>(origLhs.getType())) {
-      if (lhsExtInType.getElementTypeBitWidth() <= 8) {
-        Type targetLhsExtTy =
-            matchContainerType(rewriter.getI8Type(), lhsExtInType);
-        if (mmlaOp == MMLA::Signed || mmlaOp == MMLA::Mixed)
-          extLhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetLhsExtTy,
-                                                         origLhs);
-        else
-          extLhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetLhsExtTy,
-                                                         origLhs);
-      }
-    }
-    if (auto rhsExtInType = dyn_cast<mlir::VectorType>(origRhs.getType())) {
-      if (rhsExtInType.getElementTypeBitWidth() <= 8) {
-        Type targetRhsExtTy =
-            matchContainerType(rewriter.getI8Type(), rhsExtInType);
-        if (mmlaOp == MMLA::Unsigned || mmlaOp == MMLA::Mixed)
-          extRhs = rewriter.createOrFold<arith::ExtUIOp>(loc, targetRhsExtTy,
-                                                         origRhs);
-        else
-          extRhs = rewriter.createOrFold<arith::ExtSIOp>(loc, targetRhsExtTy,
-                                                         origRhs);
-      }
-    }
-
-    if (!extLhs || !extRhs) {
-      return failure();
-    }
-
-    // Initial accumulator for the final result. This is the un-tiled result if
-    // tiling is done.
-    Value result = rewriter.create<arith::ConstantOp>(
-        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
-
-    SmallVector<int64_t> unrolledSize = *op.getShapeForUnroll();
-    SmallVector<int64_t> smmlaShape = {2, 8};
-    SmallVector<int64_t> loopOrder = {0, 1};
-    if (unrolledSize.size() == 3) {
-      smmlaShape.insert(smmlaShape.begin(), isVecmat ? 1 : 2);
-      loopOrder.push_back(2);
-    }
-
-    // Keep track of the previous accumulator when tiling over K.
-    Value kAcc;
-    for (SmallVector<int64_t> offsets :
-         StaticTileOffsetRange(unrolledSize, smmlaShape, loopOrder)) {
-      // Helper to compute the new shape of each operand and extract the slice.
-      auto extractOperand = [&](Value operand, AffineMap permutationMap,
-                                ArrayRef<int64_t> operandOffsets) {
-        SmallVector<int64_t> operandShape =
-            applyPermutationMap(permutationMap, ArrayRef<int64_t>(smmlaShape));
-        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
-        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
-            loc, operand, operandOffsets, operandShape, operandStrides);
-      };
-
-      // Extract tiled lhs, rhs, and acc
-      AffineMap lhsPermutationMap = op.getIndexingMapsArray()[0];
-      SmallVector<int64_t> lhsOffsets =
-          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledLhs = extractOperand(extLhs, lhsPermutationMap, lhsOffsets);
-      AffineMap rhsPermutationMap = op.getIndexingMapsArray()[1];
-      SmallVector<int64_t> rhsOffsets =
-          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledRhs = extractOperand(extRhs, rhsPermutationMap, rhsOffsets);
-      AffineMap accPermutationMap = op.getIndexingMapsArray()[2];
-      SmallVector<int64_t> accOffsets =
-          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
-      Value tiledAcc =
-          extractOperand(op.getAcc(), accPermutationMap, accOffsets);
-
-      auto inputElementType =
-          cast<ShapedType>(tiledLhs.getType()).getElementType();
-      auto accElementType =
-          cast<ShapedType>(tiledAcc.getType()).getElementType();
-      auto inputExpandedType = VectorType::get({2, 8}, inputElementType);
-      auto outputExpandedType = VectorType::get({2, 2}, accElementType);
-
-      // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
-      // rows along dimM. Expand their shapes to match the smmla op.
-      if (isVecmat) {
-        auto expandForSMMLA = [&](Value tiledOperand,
-                                  VectorType expandedTypeType) {
-          auto emptyOperand = rewriter.create<arith::ConstantOp>(
-              loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
-          SmallVector<int64_t> offsets(
-              cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
-          SmallVector<int64_t> strides(
-              cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
-          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
-              loc, tiledOperand, emptyOperand, offsets, strides);
-        };
-        tiledLhs = expandForSMMLA(tiledLhs, inputExpandedType);
-        tiledAcc = expandForSMMLA(tiledAcc, outputExpandedType);
-      }
-
-      // Transpose ACC if doing signed by unsigned multiplication, because we're
-      // using the instruction for unsigned by signed multiplication with
-      // reversed operands.
-      if (mmlaOp == MMLA::MixedSwapped)
-        tiledAcc = rewriter.create<vector::TransposeOp>(
-            loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
-
-      // Collapse tiled operands to 1D vectors required by smmla intrinsic
-      auto collapsedInputType =
-          VectorType::get(inputExpandedType.getNumElements(), inputElementType);
-      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
-          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
-      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
-          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
-      auto collapsedOutputType =
-          VectorType::get(outputExpandedType.getNumElements(), accElementType);
-
-      bool initialKAcc = offsets.back() == 0;
-      Value collapsedRes;
-      if (!initialKAcc) {
-        collapsedRes = kAcc;
-      } else {
-        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
-            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
-      }
-
-      // Insert contract op
-      kAcc = createMMLA(rewriter, mmlaOp, op.getLoc(), collapsedRes.getType(),
-                        collapsedRes, collapsedLhs, collapsedRhs);
-
-      // Reshape output back to 2D
-      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
-          kAcc.getLoc(), tiledAcc.getType(), kAcc);
-
-      // Because of the reversed operands the result is obtained transposed.
-      // Transpose it back.
-      if (mmlaOp == MMLA::MixedSwapped)
-        tiledRes = rewriter.create<vector::TransposeOp>(
-            loc, tiledRes, ArrayRef<int64_t>({1, 0}));
-
-      // With vecmat, only one row of tiled ACC can be inserted into the final
-      // result
-      if (isVecmat) {
-        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
-      }
-
-      // Insert the tiled result back into the non-tiled result of the
-      // contract op.
-      SmallVector<int64_t> strides(
-          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
-      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
-          loc, tiledRes, result, accOffsets, strides);
-    }
-
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-} // namespace
-
-void mlir::arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(
-    RewritePatternSet &patterns) {
-  MLIRContext *context = patterns.getContext();
-  patterns.add<LowerContractionToNeonI8MMPattern>(context, /*benefit=*/2);
-}
diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
index b7703ff0393eb..f7a9499e2db07 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp
@@ -12,7 +12,7 @@
 // TODO: There may be opportunities to unify this with a similar pattern
 // for Neon. See:
 //   https://github.com/llvm/llvm-project/issues/145559
-//   LowerContractionToNeonI8MMPattern.cpp
+//   LowerContractToNeonPatterns.cpp
 //
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
new file mode 100644
index 0000000000000..229c4e5b2dc3a
--- /dev/null
+++ b/mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir
@@ -0,0 +1,225 @@
+// RUN: mlir-opt %s --transform-interpreter | FileCheck %s
+
+// Test lowering of vector.contract to BFMMLA operations.
+// For each iteration [I, J, K] sub-tiles are extracted from offsets as follows:
+//   LHS: [2*I, 4*K]
+//   RHS: [2*J, 4*K]
+//   ACC: [2*I, 2*J]
+// Sub-tile insert offsets for the result are the same as for ACC (there are
+// redundant inserts).
+
+// CHECK-LABEL: func.func @vector_contract_to_bfmmla
+// CHECK-SAME: %[[LHS:.+]]: vector<4x8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4x4xf32>
+
+// %[[INIT_RES:.+]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+
+// Iteration [0, 0, 0]
+// Extract sub-tiles from each of LHS, RHS and ACC
+// %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+
+// Flatten the operands to fit the `bfmmla` operation types
+// %[[T3:.+]] = vector.shape_cast %[[T0]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T4:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T5:.+]] = vector.shape_cast %[[T2]] : vector<2x2xf32> to vector<4xf32>
+
+// Perform the matrix multiply and accumulate
+// %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T5]], %[[T3]], %[[T4]] : vector<8xbf16> to vector<4xf32>
+
+// Un-flatten the output sub-tile and insert into the result
+// %[[T7:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T7]], %[[INIT_RES]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 0, 1]
+// %[[T9:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T10:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T11:.+]] = vector.shape_cast %[[T9]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T12:.+]] = vector.shape_cast %[[T10]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T13:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T11]], %[[T12]] : vector<8xbf16> to vector<4xf32>
+// %[[T14:.+]] = vector.shape_cast %[[T13]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T14]], %[[TMP_RES_0]] {offsets = [0, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 1, 0]
+// %[[T16:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T17:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T18:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T19:.+]] = vector.shape_cast %[[T16]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T20:.+]] = vector.shape_cast %[[T17]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T21:.+]] = vector.shape_cast %[[T18]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T21]], %[[T19]], %[[T20]] : vector<8xbf16> to vector<4xf32>
+// %[[T23:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T23]], %[[TMP_RES_1]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [0, 1, 1]
+// %[[T25:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T26:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T27:.+]] = vector.shape_cast %[[T25]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T28:.+]] = vector.shape_cast %[[T26]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T29:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T27]], %[[T28]] : vector<8xbf16> to vector<4xf32>
+// %[[T30:.+]] = vector.shape_cast %[[T29]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_3:.+]] = vector.insert_strided_slice %[[T30]], %[[TMP_RES_2]] {offsets = [0, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 0, 0]
+// %[[T32:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T33:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T34:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T35:.+]] = vector.shape_cast %[[T32]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T36:.+]] = vector.shape_cast %[[T33]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_2:.+]] = arm_neon.intr.bfmmla %[[T37]], %[[T35]], %[[T36]] : vector<8xbf16> to vector<4xf32>
+// %[[T39:.+]] = vector.shape_cast %[[K_ACC_2]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_4:.+]] = vector.insert_strided_slice %[[T39]], %[[TMP_RES_3]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 0, 1]
+// %[[T41:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T42:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T43:.+]] = vector.shape_cast %[[T41]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T44:.+]] = vector.shape_cast %[[T42]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T45:.+]] = arm_neon.intr.bfmmla %[[K_ACC_2]], %[[T43]], %[[T44]] : vector<8xbf16> to vector<4xf32>
+// %[[T46:.+]] = vector.shape_cast %[[T45]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_5:.+]] = vector.insert_strided_slice %[[T46]], %[[TMP_RES_4]] {offsets = [2, 0], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 1, 0]
+// %[[T48:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T49:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T50:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
+// %[[T51:.+]] = vector.shape_cast %[[T48]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T52:.+]] = vector.shape_cast %[[T49]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T53:.+]] = vector.shape_cast %[[T50]] : vector<2x2xf32> to vector<4xf32>
+// %[[K_ACC_3:.+]] = arm_neon.intr.bfmmla %[[T53]], %[[T51]], %[[T52]] : vector<8xbf16> to vector<4xf32>
+// %[[T55:.+]] = vector.shape_cast %[[K_ACC_3]] : vector<4xf32> to vector<2x2xf32>
+// %[[TMP_RES_6:.+]] = vector.insert_strided_slice %[[T55]], %[[TMP_RES_5]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// Iteration [1, 1, 1]
+// %[[T57:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T58:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// %[[T59:.+]] = vector.shape_cast %[[T57]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T60:.+]] = vector.shape_cast %[[T58]] : vector<2x4xbf16> to vector<8xbf16>
+// %[[T61:.+]] = arm_neon.intr.bfmmla %[[K_ACC_3]], %[[T59]], %[[T60]] : vector<8xbf16> to vector<4xf32>
+// %[[T62:.+]] = vector.shape_cast %[[T61]] : vector<4xf32> to vector<2x2xf32>
+// %[[RESULT:.+]] = vector.insert_strided_slice %[[T62]], %[[TMP_RES_6]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
+
+// return %[[RESULT]] : vector<4x4xf32>
+
+func.func @vector_contract_to_bfmmla(%lhs: vector<4x8xbf16>,
+                                     %rhs: vector<4x8xbf16>,
+                                     %acc: vector<4x4xf32>) -> vector<4x4xf32> {
+  %0 = vector.contract { indexing_maps = [
+                           affine_map<(m, n, k) -> (m, k)>,
+                           affine_map<(m, n, k) -> (n, k)>,
+                           affine_map<(m, n, k) -> (m, n)>
+                         ],
+                         iterator_types = ["parallel", "parallel", "reduction"],
+                         kind = #vector.kind<add>
+                       }
+    %lhs, %rhs, %acc : vector<4x8xbf16>, vector<4x8xbf16> into vector<4x4xf32>
+
+  return %0 : vector<4x4xf32>
+}
+
+// Test lowering of vector.contract, representing vector by matrix multiply and
+// accumulate, to BFMMLA operations.
+
+// For each iteration [J, K] sub-tiles are extracted from offsets as follows:
+//   LHS: [4*K]
+//   RHS: [2*J, 4*K]
+//   ACC: [2*J]
+// Sub-tile insert offsets for the result are the same as for ACC (there are
+// redundant inserts).
+// CHECK-LABEL: func.func @vector_contract_vecmat_to_bfmmla
+// CHECK-SAME: %[[LHS:.+]]: vector<8xbf16>, %[[RHS:.+]]: vector<4x8xbf16>, %[[ACC:.+]]: vector<4xf32>) -> vector<4xf32> {
+// CHECK: %[[ACC_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x2xf32>
+// CHECK: %[[LHS_PAD_Z:.+]] = arith.constant dense<0.000000e+00> : vector<2x4xbf16>
+// CHECK: %[[RES_INIT:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32>
+
+// Iteration [0, 0]
+// Extract sub-tiles
+// CHECK: %[[T0:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T1:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T2:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [0], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+
+// Pad LHS sub-tile/vector with an extra row of zeroes
+// CHECK: %[[T3:.+]] = vector.insert_strided_slice %[[T0]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+
+// Pad ACC sub-tile/vector with an extra row of zeroes
+// CHECK: %[[T4:.+]] = vector.insert_strided_slice %[[T2]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32>
+
+// Flatten the operands to fit the `bfmmla` operation types
+// CHECK: %[[T5:.+]] = vector.shape_cast %[[T3]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T6:.+]] = vector.shape_cast %[[T1]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T7:.+]] = vector.shape_cast %[[T4]] : vector<2x2xf32> to vector<4xf32>
+
+// Perform the matrix multiply and accumulate
+// CHECK: %[[K_ACC_0:.+]] = arm_neon.intr.bfmmla %[[T7]], %[[T5]], %[[T6]] : vector<8xbf16> to vector<4xf32>
+
+// Un-flatten the output sub-tile
+// CHECK: %[[T9:.+]] = vector.shape_cast %[[K_ACC_0]] : vector<4xf32> to vector<2x2xf32>
+
+// Extract the first row (the second row is padding) and insert into the result
+// CHECK: %[[T10:.+]] = vector.extract %[[T9]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_0:.+]] = vector.insert_strided_slice %[[T10]], %[[RES_INIT]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [0, 1]
+// CHECK: %[[T12:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T13:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [0, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T14:.+]] = vector.insert_strided_slice %[[T12]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T15:.+]] = vector.shape_cast %[[T14]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T16:.+]] = vector.shape_cast %[[T13]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T17:.+]] = arm_neon.intr.bfmmla %[[K_ACC_0]], %[[T15]], %[[T16]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T18:.+]] = vector.shape_cast %[[T17]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T19:.+]] = vector.extract %[[T18]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_1:.+]] = vector.insert_strided_slice %[[T19]], %[[TMP_RES_0]] {offsets = [0], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [1, 0]
+// CHECK: %[[T21:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [0], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T22:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 0], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T23:.+]] = vector.extract_strided_slice %[[ACC]] {offsets = [2], sizes = [2], strides = [1]} : vector<4xf32> to vector<2xf32>
+// CHECK: %[[T24:.+]] = vector.insert_strided_slice %[[T21]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T25:.+]] = vector.insert_strided_slice %[[T23]], %[[ACC_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<2xf32> into vector<2x2xf32>
+// CHECK: %[[T26:.+]] = vector.shape_cast %[[T24]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T27:.+]] = vector.shape_cast %[[T22]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T28:.+]] = vector.shape_cast %[[T25]] : vector<2x2xf32> to vector<4xf32>
+// CHECK: %[[K_ACC_1:.+]] = arm_neon.intr.bfmmla %[[T28]], %[[T26]], %[[T27]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T30:.+]] = vector.shape_cast %[[K_ACC_1]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T31:.+]] = vector.extract %[[T30]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[TMP_RES_2:.+]] = vector.insert_strided_slice %[[T31]], %[[TMP_RES_1]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+
+// Iteration [1, 1]
+// CHECK: %[[T33:.+]] = vector.extract_strided_slice %[[LHS]] {offsets = [4], sizes = [4], strides = [1]} : vector<8xbf16> to vector<4xbf16>
+// CHECK: %[[T34:.+]] = vector.extract_strided_slice %[[RHS]] {offsets = [2, 4], sizes = [2, 4], strides = [1, 1]} : vector<4x8xbf16> to vector<2x4xbf16>
+// CHECK: %[[T35:.+]] = vector.insert_strided_slice %[[T33]], %[[LHS_PAD_Z]] {offsets = [0, 0], strides = [1]} : vector<4xbf16> into vector<2x4xbf16>
+// CHECK: %[[T36:.+]] = vector.shape_cast %[[T35]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T37:.+]] = vector.shape_cast %[[T34]] : vector<2x4xbf16> to vector<8xbf16>
+// CHECK: %[[T38:.+]] = arm_neon.intr.bfmmla %[[K_ACC_1]], %[[T36]], %[[T37]] : vector<8xbf16> to vector<4xf32>
+// CHECK: %[[T39:.+]] = vector.shape_cast %[[T38]] : vector<4xf32> to vector<2x2xf32>
+// CHECK: %[[T40:.+]] = vector.extract %[[T39]][0] : vector<2xf32> from vector<2x2xf32>
+// CHECK: %[[RESULT:.+]] = vector.insert_strided_slice %[[T40]], %[[TMP_RES_2]] {offsets = [2], strides = [1]} : vector<2xf32> into vector<4xf32>
+// CHECK: return %[[RESULT]] : vector<4xf32>
+func.func @vector_contract_vecmat_to_bfmmla(%lhs: vector<8xbf16>,
+                                            %rhs: vector<4x8xbf16>,
+                                            %acc: vector<4xf32>) -> vector<4xf32> {
+  %0 = vector.contract { indexing_maps = [
+                           affine_map<(n, k) -> (k)>,
+                           affine_map<(n, k) -> (n, k)>,
+                           affine_map<(n, k) -> (n)>
+                         ],
+                         iterator_types = ["parallel", "reduction"],
+                         kind = #vector.kind<add>
+                       }
+    %lhs, %rhs, %acc : vector<8xbf16>, vector<4x8xbf16> into vector<4xf32>
+
+  return %0 : vector<4xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%module: !transform.any_op {transform.readonly}) {
+    %func = transform.structured.match ops{["func.func"]} in %module : (!transform.any_op) -> !transform.op<"func.func">
+
+    transform.apply_patterns to %func {
+      transform.apply_patterns.arm_neon.vector_contract_to_bfmmla
+    } : !transform.op<"func.func">
+
+    transform.yield
+  }
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
new file mode 100644
index 0000000000000..b62ae040f364b
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir
@@ -0,0 +1,176 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-neon enable-arm-bf16' \
+// DEFINE:   --expand-strided-metadata --convert-to-llvm --finalize-memref-to-llvm \
+// DEFINE:   --lower-affine --convert-arith-to-llvm --reconcile-unrealized-casts \
+// DEFINE: -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+bf16" \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && FileCheck %s --input-file=%t -check-prefix CHECK-IR && %{run} | FileCheck %s
+
+#packed_maps = [
+  affine_map<(m, n, k) -> (m, k)>,
+  affine_map<(m, n, k) -> (n, k)>,
+  affine_map<(m, n, k) -> (m, n)>
+]
+
+//
+// Test the lowering of `vector.contract` using the
+// `LowerContractionToNeonBFMMLAPattern`.
+//
+// The operation that the `vector.contract` in this test performs is matrix
+// multiplication with accumulate
+//     OUT = ACC + LHS * RHS
+// of two BFloat16 matrices LHS and RHS, and a Float32 matrix ACC into a
+// Float32 OUT.
+//
+// Tested are the computed values as well as the emission of the relevant
+// `ArmNeon` dialect operation (`arm_neon.intr.bfmmla`).
+//
+// The pattern above handles (therefore this test prepares) input/output
+// vectors with specific shapes:
+//   * LHS:      vector<MxKxbf16>
+//   * RHS:      vector<NxKxbf16>
+//   * ACC, OUT: vector<MxNxf32>
+// where M and N are even and K is divisible by 4.
+// Note that the RHS is transposed.
+// This data layout makes it efficient to load data into SIMD
+// registers in the layout expected by the BFMMLA instruction.
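+// (A single BFMMLA instruction computes a 2x2 f32 sub-tile of the result from
+// a 2x4 bf16 sub-tile of LHS and a 2x4 bf16 sub-tile of the transposed RHS,
+// so both operands are consumed in contiguous row-major 2x4 blocks.)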
+// Such a `vector.contract` is representative of the code we aim to generate
+// by vectorisation of `linalg.mmt4d`.
+//
+// In this specific test we use M == 4, N == 4, and K == 4.
+
+// CHECK-IR-LABEL: llvm.func @matrix_by_matrix_mul_and_acc
+// CHECK-IR-COUNT-4: arm_neon.intr.bfmmla
+func.func @matrix_by_matrix_mul_and_acc() {
+
+  %c0 = arith.constant 0 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %c0_bf16 = arith.constant 0.0 : bf16
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[[ 0.7,  1.0, -0.1,  1.8],
+                                   [-0.5,  0.9,  0.7, -0.7],
+                                   [ 0.5, -1.3, -2.2,  0.1],
+                                   [-0.7,  1.0,  1.7, -1.0]]> : vector<4x4xf32>
+
+  %acc_mem = memref.alloc() : memref<4x4xf32>
+  vector.transfer_write %acc_cst, %acc_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xf32>, memref<4x4xf32>
+  %acc = vector.transfer_read %acc_mem[%c0, %c0], %c0_f32 {in_bounds = [true, true]} : memref<4x4xf32>, vector<4x4xf32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[[ 0.1,  0.7, -0.9,  1.3],
+                                   [-1.6,  0.7, -0.3, -0.3],
+                                   [-0.4,  0.6,  0.8, -0.5],
+                                   [-0.6, -1.0, -1.0, -1.0]]> : vector<4x4xbf16>
+
+  %lhs_mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %lhs_cst, %lhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+  %lhs = vector.transfer_read %lhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[ 0.6,  1.3,  0.1, -0.9],
+                                   [ 0.5,  1.6,  1.8,  1.6],
+                                   [-0.2,  0.4,  1.0,  0.4],
+                                   [-1.3, -0.2, -2.2,  0.3]]> : vector<4x4xbf16>
+
+  %rhs_mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+  %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // Matrix multiplication and accumulate with transposed RHS.
+  %0 = vector.contract {indexing_maps = #packed_maps,
+                        iterator_types = ["parallel", "parallel", "reduction"],
+                        kind = #vector.kind<add>} %lhs, %rhs, %acc
+    : vector<4x4xbf16>, vector<4x4xbf16> into vector<4x4xf32>
+
+  // Display the result of the multiplication
+  vector.print str "Result(BFMMLA):\n"
+  %u0 = vector.extract %0[0] : vector<4xf32> from vector<4x4xf32>
+  %u1 = vector.extract %0[1] : vector<4xf32> from vector<4x4xf32>
+  %u2 = vector.extract %0[2] : vector<4xf32> from vector<4x4xf32>
+  %u3 = vector.extract %0[3] : vector<4xf32> from vector<4x4xf32>
+  vector.print %u0 : vector<4xf32>
+  vector.print %u1 : vector<4xf32>
+  vector.print %u2 : vector<4xf32>
+  vector.print %u3 : vector<4xf32>
+
+  return
+}
+
+// Test when the LHS is a one-dimensional vector.
+//
+// In the vector by matrix case the shapes are as follows:
+//   * LHS:      vector<Kxbf16>
+//   * RHS:      vector<NxKxbf16>
+//   * ACC, OUT: vector<Nxf32>
+// where N is even and K is divisible by 4.
+// In this specific test we use N == 4 and K == 4.
+
+// CHECK-IR-LABEL: llvm.func @vector_by_matrix_mul_and_acc
+// CHECK-IR-COUNT-2: arm_neon.intr.bfmmla
+func.func @vector_by_matrix_mul_and_acc() {
+  %c0 = arith.constant 0 : index
+  %c0_f32 = arith.constant 0.0 : f32
+  %c0_bf16 = arith.constant 0.0 : bf16
+
+  // Accumulator test data
+  %acc_cst = arith.constant dense<[0.7, 1.0, -0.1, 1.8]> : vector<4xf32>
+
+  %acc_mem = memref.alloc() : memref<4xf32>
+  vector.transfer_write %acc_cst, %acc_mem[%c0] {in_bounds = [true] } : vector<4xf32>, memref<4xf32>
+  %acc = vector.transfer_read %acc_mem[%c0], %c0_f32 {in_bounds = [true]} : memref<4xf32>, vector<4xf32>
+
+  // LHS test data
+  %lhs_cst = arith.constant dense<[0.1, 0.7, -0.9, 1.3]> : vector<4xbf16>
+
+  %lhs_mem = memref.alloc() : memref<4xbf16>
+  vector.transfer_write %lhs_cst, %lhs_mem[%c0] {in_bounds = [true] } : vector<4xbf16>, memref<4xbf16>
+  %lhs = vector.transfer_read %lhs_mem[%c0], %c0_bf16 {in_bounds = [true]} : memref<4xbf16>, vector<4xbf16>
+
+  // RHS test data
+  %rhs_cst = arith.constant dense<[[ 0.6,  1.3,  0.1, -0.9],
+                                   [ 0.5,  1.6,  1.8,  1.6],
+                                   [-0.2,  0.4,  1.0,  0.4],
+                                   [-1.3, -0.2, -2.2,  0.3]]> : vector<4x4xbf16>
+
+  %rhs_mem = memref.alloc() : memref<4x4xbf16>
+  vector.transfer_write %rhs_cst, %rhs_mem[%c0, %c0] {in_bounds = [true, true] } : vector<4x4xbf16>, memref<4x4xbf16>
+  %rhs = vector.transfer_read %rhs_mem[%c0, %c0], %c0_bf16 {in_bounds = [true, true]} : memref<4x4xbf16>, vector<4x4xbf16>
+
+  // Vector by matrix multiplication and accumulate with transposed RHS.
+  %0 = vector.contract { indexing_maps = [
+                           affine_map<(n, k) -> (k)>,
+                           affine_map<(n, k) -> (n, k)>,
+                           affine_map<(n, k) -> (n)>
+                         ],
+                         iterator_types = ["parallel", "reduction"],
+                         kind = #vector.kind<add>
+                       }
+    %lhs, %rhs, %acc : vector<4xbf16>, vector<4x4xbf16> into vector<4xf32>
+
+  // Display the result of the multiplication
+  vector.print str "Result(BFMMLA, vecmat):\n"
+  vector.print %0 : vector<4xf32>
+
+  return
+}
+
+func.func @main() {
+  // CHECK-LABEL: Result(BFMMLA):
+  // CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 )
+  // CHECK: ( -0.316515, 0.196875, 0.879375, 1.80924 )
+  // CHECK: ( 1.56867, 0.101367, -1.2784, -1.41579 )
+  // CHECK: ( -1.56041, -4.30078, 0.0196488, 1.88269 )
+  func.call @matrix_by_matrix_mul_and_acc() : () -> ()
+
+  // CHECK-LABEL: Result(BFMMLA, vecmat):
+  // CHECK: ( 0.411922, 2.63254, -0.219259, 3.89965 )
+  func.call @vector_by_matrix_mul_and_acc() : () -> ()
+
+  return
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
index 1ce55ca05c90e..f6012bbd3d0b2 100644
--- a/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir
@@ -240,7 +240,7 @@ func.func @test_usmmla() {
 
 // Test the operation where LHS is interpreted as signed and RHS is interpreted
 // as unsigned. In this test we ultimately emit and execute the `usmmla`
-// instruction with reversed operands, see `LowerContractionToNeonI8MMPattern.cpp`
+// instruction with reversed operands, see `LowerContractToNeonPatterns.cpp`
 // for more details.
 
 // CHECK-IR-LABEL: llvm.func @test_summla