Skip to content

[MLIR][AArch64] Lower vector.contract to Neon FEAT_BF16 operations #148198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

momchil-velikov
Copy link
Collaborator

This is split in two commits:

  • refactor I8MM lowering to make it easier to add ...
  • ... BF16 lowering

This patch refactors the pattern in `Transforms/LowerContractionToNeonI8MMPattern.cpp`
using an approach similar to the one in #147052
to prepare for adding BF16 support.
@llvmbot
Copy link
Member

llvmbot commented Jul 11, 2025

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-sve
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-neon

Author: Momchil Velikov (momchil-velikov)

Changes

This is split in two commits:

  • refactor I8MM lowering to make it easier to add ...
  • ... BF16 lowering

Patch is 66.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/148198.diff

12 Files Affected:

  • (modified) mlir/include/mlir/Conversion/Passes.td (+4)
  • (modified) mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td (+13-2)
  • (modified) mlir/include/mlir/Dialect/ArmNeon/Transforms.h (+2-2)
  • (modified) mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp (+3-1)
  • (modified) mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp (+6-1)
  • (modified) mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt (+1-1)
  • (added) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp (+499)
  • (removed) mlir/lib/Dialect/ArmNeon/Transforms/LowerContractionToNeonI8MMPattern.cpp (-364)
  • (modified) mlir/lib/Dialect/ArmSVE/Transforms/LowerContractionToSVEI8MMPattern.cpp (+1-1)
  • (added) mlir/test/Dialect/ArmNeon/vector-bfmmla.mlir (+225)
  • (added) mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-bfmmla.mlir (+176)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/ArmNeon/vector-contract-i8mm.mlir (+1-1)
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 5a864865adffc..4f304b39a0528 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -1437,6 +1437,10 @@ def ConvertVectorToLLVMPass : Pass<"convert-vector-to-llvm"> {
            "bool", /*default=*/"false",
            "Enables the use of Arm FEAT_I8MM instructions while lowering "
            "the vector dialect.">,
+    Option<"armBF16", "enable-arm-bf16",
+           "bool", /*default=*/"false",
+           "Enables the use of Arm FEAT_BF16 instructions while lowering "
+           "the vector dialect.">,
     Option<"x86Vector", "enable-x86vector",
            "bool", /*default=*/"false",
            "Enables the use of X86Vector dialect while lowering the vector "
diff --git a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
index bcaca7da967fa..35747126d3db1 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
+++ b/mlir/include/mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.td
@@ -17,8 +17,19 @@ def ApplyArmNeonContractionToI8MMPatternsOp
          "apply_patterns.arm_neon.vector_contract_to_i8mm",
          [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
   let description = [{
-    Indicates that vector.contract operations should be lowered to
-    finer-grained vector primitives from the ArmNeon dialect.
+    Indicates that vector contract operations should be lowered to
+    ArmNeon dialect operations mapping to instructions from FEAT_I8MM.
+  }];
+
+  let assemblyFormat = "attr-dict";
+}
+
+def ApplyArmNeonContractionToBFMMLAPatternsOp
+    : Op<Transform_Dialect, "apply_patterns.arm_neon.vector_contract_to_bfmmla",
+         [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
+  let description = [{
+    Indicates that vector contract operations should be lowered to
+    ArmNeon dialect operations mapping to instructions from FEAT_BF16.
   }];
 
   let assemblyFormat = "attr-dict";
diff --git a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
index 2f0f634a96770..08065a3b25266 100644
--- a/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
+++ b/mlir/include/mlir/Dialect/ArmNeon/Transforms.h
@@ -13,8 +13,8 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arm_neon {
-void populateLowerContractionToNeonI8MMPatternPatterns(
-    RewritePatternSet &patterns);
+void populateLowerContractionToNeonI8MMPatterns(RewritePatternSet &patterns);
+void populateLowerContractionToNeonBFMMLAPatterns(RewritePatternSet &patterns);
 } // namespace arm_neon
 
 } // namespace mlir
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
index 549d0210af7ad..1045824c437ab 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVMPass.cpp
@@ -84,10 +84,12 @@ void ConvertVectorToLLVMPass::runOnOperation() {
     populateVectorGatherLoweringPatterns(patterns);
     if (armI8MM) {
       if (armNeon)
-        arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+        arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
       if (armSVE)
         populateLowerContractionToSVEI8MMPatternPatterns(patterns);
     }
+    if (armBF16 && armNeon)
+      arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
     (void)applyPatternsGreedily(getOperation(), std::move(patterns));
   }
 
diff --git a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
index d07e6a52d8b5f..d069bde6d9979 100644
--- a/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
+++ b/mlir/lib/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.cpp
@@ -20,7 +20,12 @@ using namespace mlir;
 
 void transform::ApplyArmNeonContractionToI8MMPatternsOp::populatePatterns(
     RewritePatternSet &patterns) {
-  arm_neon::populateLowerContractionToNeonI8MMPatternPatterns(patterns);
+  arm_neon::populateLowerContractionToNeonI8MMPatterns(patterns);
+}
+
+void transform::ApplyArmNeonContractionToBFMMLAPatternsOp::populatePatterns(
+    RewritePatternSet &patterns) {
+  arm_neon::populateLowerContractionToNeonBFMMLAPatterns(patterns);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
index 06bafde451cbb..368dacac7b835 100644
--- a/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/CMakeLists.txt
@@ -1,5 +1,5 @@
 add_mlir_dialect_library(MLIRArmNeonTransforms
-  LowerContractionToNeonI8MMPattern.cpp
+  LowerContractToNeonPatterns.cpp
 
   DEPENDS
   MLIRArmNeonIncGen
diff --git a/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
new file mode 100644
index 0000000000000..06746daa8075b
--- /dev/null
+++ b/mlir/lib/Dialect/ArmNeon/Transforms/LowerContractToNeonPatterns.cpp
@@ -0,0 +1,499 @@
+//===- LowerContractToNeonPatterns.cpp - Contract to I8MM/BF16 --*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements lowering patterns from vector.contract to operations
+// that map to instructions from the Neon FEAT_I8MM and FEAT_BF16 extensions.
+//
+// TODO: There may be opportunities to unify this with a similar pattern
+// for SVE. See:
+//   https://github.com/llvm/llvm-project/issues/145559
+//   LowerContractionToSVEI8MMPattern.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ArmNeon/ArmNeonDialect.h"
+#include "mlir/Dialect/ArmNeon/Transforms.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+
+#define DEBUG_TYPE "lower-contract-to-arm-neon"
+
+using namespace mlir;
+using namespace mlir::arm_neon;
+
+namespace {
+/// Get the operand of a `vector.contract`. This function is intended to
+/// abstract away from the particular way a value is extended before feeding it
+/// into the `vector.contract` - via zero-extend or an explicit or implicit
+/// sign-extend (for implicit sign-extension see `vector.contract`
+/// documentation).
+///
+/// The template parameter `Op` indicates the extension operation (explicit or
+/// implicit) for which we are checking.
+///
+/// Return a value only for extensions from `iN` (N <= 8) to `i32`.
+template <typename Op>
+std::optional<Value> getExtOperand(Value v) {
+
+  static_assert(llvm::is_one_of<Op, arith::ExtSIOp, arith::ExtUIOp>::value,
+                "Must be instantiated with either sign- or zero- extension op");
+
+  // If the operand is not defined by an explicit extend operation of the
+  // accepted operation type allow for an implicit sign-extension.
+  auto extOp = dyn_cast_or_null<Op>(v.getDefiningOp());
+  if (!extOp) {
+    if constexpr (std::is_same<Op, arith::ExtSIOp>::value) {
+      auto eltTy = cast<VectorType>(v.getType()).getElementType();
+      if (!eltTy.isSignlessInteger() || eltTy.getIntOrFloatBitWidth() > 8)
+        return {};
+      return v;
+    }
+    return {};
+  }
+
+  // If the operand is defined by an explicit extend operation of the accepted
+  // operation type, check it's extended from `iN` (N <= 8) to `i32`.
+  auto inOp = extOp.getIn();
+  auto inTy = dyn_cast<VectorType>(inOp.getType());
+  if (!inTy)
+    return {};
+  auto inEltTy = inTy.getElementType();
+  if (!inEltTy.isSignlessInteger() || inEltTy.getIntOrFloatBitWidth() > 8)
+    return {};
+
+  auto outTy = dyn_cast<VectorType>(extOp.getType());
+  if (!(outTy && outTy.getElementType().isSignlessInteger(32)))
+    return {};
+
+  return inOp;
+}
+
+/// Helper function to extend a vector with elements iN, N < 8 to
+/// a vector of i8. Do sign extension if the parameter `signExt` is true,
+/// zero extension otherwise.
+Value extendSmallIntVector(Location loc, VectorType srcTy, Value val,
+                           bool signExt, PatternRewriter &rewriter) {
+  Type targetTy = srcTy.clone(rewriter.getI8Type());
+  return signExt ? rewriter.createOrFold<arith::ExtSIOp>(loc, targetTy, val)
+                 : rewriter.createOrFold<arith::ExtUIOp>(loc, targetTy, val);
+}
+
+class VectorContractRewriter {
+protected:
+  // Designate the operation (resp. instruction) used to do sub-tile matrix
+  // multiplications.
+  enum class MMLA {
+    Nop,
+    SignedInt,   // smmla
+    UnsignedInt, // ummla
+    MixedInt,    // usmmla
+    Bfloat       // bfmmla
+  };
+
+  // Lower-level operation to be emitted.
+  MMLA mmlaOp = MMLA::Nop;
+
+  // Indicate if the operands for the ArmNeon dialect operation need to be
+  // swapped. Currently this is needed in order to emulate an "summla"
+  // operation.
+  bool swapOperands = false;
+
+  // The operand tiles. These are not necessarily the operands of
+  // `vector.contract`, for example they could be operands to `arith.extsi`
+  // that is in turn fed into `vector.contract`.
+  Value lhs;
+  Value rhs;
+  Value acc;
+
+  // The dimensions logically corresponding to matrix multiplication of
+  // MxK * KxN -> MxN. The operands and the result do not necessarily have these
+  // shapes, for example RHS could be NxK with a transposing indexing map.
+  int64_t dimM = 0;
+  int64_t dimN = 0;
+  int64_t dimK = 0;
+
+  // Unroll iteration bounds. See documentation for `StaticTileOffsetRange`.
+  SmallVector<int64_t> iterationBounds;
+
+  // Sub-tile shape. The algorithm handles operand shapes, which are multiples
+  // of this shape.
+  SmallVector<int64_t> subTileShape;
+
+  // Create the matrix multiply and accumulate operation according to `mmlaOp`.
+  Value createMMLA(PatternRewriter &rewriter, Location loc, Value acc,
+                   Value lhs, Value rhs) {
+
+    if (swapOperands)
+      std::swap(lhs, rhs);
+    switch (mmlaOp) {
+    case MMLA::SignedInt:
+      return rewriter.createOrFold<arm_neon::SmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::UnsignedInt:
+      return rewriter.createOrFold<arm_neon::UmmlaOp>(loc, acc.getType(), acc,
+                                                      lhs, rhs);
+    case MMLA::MixedInt:
+      return rewriter.createOrFold<arm_neon::UsmmlaOp>(loc, acc.getType(), acc,
+                                                       lhs, rhs);
+    case MMLA::Bfloat:
+      return rewriter.create<arm_neon::BfmmlaOp>(loc, acc.getType(), acc, lhs,
+                                                 rhs);
+    case MMLA::Nop:
+      llvm_unreachable("Uninitialized operation type");
+    }
+  }
+
+  // Check common preconditions for applying the patterns and initialize
+  // logical dimensions.
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    // Check iterator types for matrix multiplication.
+    SmallVector<vector::IteratorType> itTypes = op.getIteratorTypesArray();
+    if (!((itTypes.size() == 3 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::parallel &&
+            itTypes[2] == vector::IteratorType::reduction)) ||
+          (itTypes.size() == 2 &&
+           (itTypes[0] == vector::IteratorType::parallel &&
+            itTypes[1] == vector::IteratorType::reduction))))
+      return rewriter.notifyMatchFailure(
+          op, "iterator types do not correspond to matrix multiplication");
+
+    // Avoid 0-D vectors and 1-D rhs:
+    VectorType lhsType = op.getLhsType();
+    VectorType rhsType = op.getRhsType();
+    if (!lhsType.hasRank() || !rhsType.hasRank() || lhsType.getRank() > 2 ||
+        rhsType.getRank() != 2)
+      return rewriter.notifyMatchFailure(op, "Invalid operand rank");
+
+    // This codegen does not work for scalable vectors. Return failure so this
+    // pattern is not accidentally chosen over patterns that lower to ArmSVE.
+    if (lhsType.isScalable() || rhsType.isScalable())
+      return rewriter.notifyMatchFailure(op,
+                                         "Not applicable to scalable vectors");
+
+    // Initialize dimensions and check for a matching K dimension.
+    dimM = lhsType.getDimSize(0);
+    dimN = rhsType.getDimSize(0);
+    dimK = rhsType.getDimSize(1);
+
+    int64_t lhsDimK;
+    if (lhsType.getRank() == 1) {
+      dimM = 1;
+      lhsDimK = lhsType.getDimSize(0);
+    } else {
+      lhsDimK = lhsType.getDimSize(1);
+    }
+
+    if (lhsDimK != dimK)
+      return rewriter.notifyMatchFailure(op, "Dimensions mismatch");
+
+    return success();
+  }
+
+public:
+  void rewrite(vector::ContractionOp op, PatternRewriter &rewriter) {
+    // Create some convenience types.
+    auto inputElementType = cast<ShapedType>(lhs.getType()).getElementType();
+    auto accElementType = cast<ShapedType>(acc.getType()).getElementType();
+    auto inputExpandedType =
+        VectorType::get({2, subTileShape.back()}, inputElementType);
+    auto outputExpandedType = VectorType::get({2, 2}, accElementType);
+
+    // One-dimensional representation of logical sub-tiles as required by the
+    // ArmNeon ops.
+    auto collapsedInputType =
+        VectorType::get(inputExpandedType.getNumElements(), inputElementType);
+    auto collapsedOutputType =
+        VectorType::get(outputExpandedType.getNumElements(), accElementType);
+
+    // Get indexing maps for a more concise/convenient access.
+    auto indexingMaps = op.getIndexingMapsArray();
+    AffineMap &lhsPermutationMap = indexingMaps[0];
+    AffineMap &rhsPermutationMap = indexingMaps[1];
+    AffineMap &accPermutationMap = indexingMaps[2];
+
+    Location loc = op.getLoc();
+
+    // Initial accumulator for the final result. This is the un-tiled result if
+    // tiling is done.
+    Value result = rewriter.create<arith::ConstantOp>(
+        loc, op.getResultType(), rewriter.getZeroAttr(op.getResultType()));
+
+    SmallVector<int64_t, 3> loopOrder = {0, 1};
+    if (iterationBounds.size() == 3)
+      loopOrder.push_back(2);
+
+    // Keep track of the previous accumulator when tiling over K.
+    Value kAcc;
+    for (SmallVector<int64_t> offsets :
+         StaticTileOffsetRange(iterationBounds, subTileShape, loopOrder)) {
+      // Helper to compute the new shape of each operand and extract the slice.
+      auto extractOperand = [&](Value operand, AffineMap permutationMap,
+                                ArrayRef<int64_t> operandOffsets) {
+        SmallVector<int64_t> operandShape = applyPermutationMap(
+            permutationMap, ArrayRef<int64_t>(subTileShape));
+        SmallVector<int64_t> operandStrides(operandOffsets.size(), 1);
+        return rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+            loc, operand, operandOffsets, operandShape, operandStrides);
+      };
+
+      // Extract tiled lhs, rhs, and acc
+      SmallVector<int64_t> lhsOffsets =
+          applyPermutationMap(lhsPermutationMap, ArrayRef<int64_t>(offsets));
+      Value tiledLhs = extractOperand(lhs, lhsPermutationMap, lhsOffsets);
+      SmallVector<int64_t> rhsOffsets =
+          applyPermutationMap(rhsPermutationMap, ArrayRef<int64_t>(offsets));
+      Value tiledRhs = extractOperand(rhs, rhsPermutationMap, rhsOffsets);
+      SmallVector<int64_t> accOffsets =
+          applyPermutationMap(accPermutationMap, ArrayRef<int64_t>(offsets));
+      Value tiledAcc = extractOperand(acc, accPermutationMap, accOffsets);
+
+      // With vecmat, tiled LHS and ACC will contain only one of 2 necessary
+      // rows along dimM. Expand their shapes to match the ArmNeon op.
+      if (dimM == 1) {
+        auto expandRowVector = [&](Value tiledOperand,
+                                   VectorType expandedTypeType) {
+          auto emptyOperand = rewriter.create<arith::ConstantOp>(
+              loc, expandedTypeType, rewriter.getZeroAttr(expandedTypeType));
+          SmallVector<int64_t> offsets(
+              cast<ShapedType>(emptyOperand.getType()).getRank(), 0);
+          SmallVector<int64_t> strides(
+              cast<ShapedType>(tiledOperand.getType()).getRank(), 1);
+          return rewriter.createOrFold<vector::InsertStridedSliceOp>(
+              loc, tiledOperand, emptyOperand, offsets, strides);
+        };
+        tiledLhs = expandRowVector(tiledLhs, inputExpandedType);
+        tiledAcc = expandRowVector(tiledAcc, outputExpandedType);
+      }
+
+      // Transpose ACC if doing signed by unsigned multiplication, because we're
+      // using the instruction for unsigned by signed multiplication with
+      // reversed operands.
+      if (swapOperands)
+        tiledAcc = rewriter.create<vector::TransposeOp>(
+            loc, tiledAcc, ArrayRef<int64_t>({1, 0}));
+
+      // Collapse tiled operands to 1D vectors required by the ArmNeon ops
+      auto collapsedLhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledLhs.getLoc(), collapsedInputType, tiledLhs);
+      auto collapsedRhs = rewriter.createOrFold<vector::ShapeCastOp>(
+          tiledRhs.getLoc(), collapsedInputType, tiledRhs);
+
+      bool initialKAcc = offsets.back() == 0;
+      Value collapsedRes;
+      if (!initialKAcc) {
+        collapsedRes = kAcc;
+      } else {
+        collapsedRes = rewriter.createOrFold<vector::ShapeCastOp>(
+            tiledAcc.getLoc(), collapsedOutputType, tiledAcc);
+      }
+
+      // Insert contract op
+      kAcc =
+          createMMLA(rewriter, loc, collapsedRes, collapsedLhs, collapsedRhs);
+
+      // Reshape output back to 2D
+      Value tiledRes = rewriter.createOrFold<vector::ShapeCastOp>(
+          kAcc.getLoc(), tiledAcc.getType(), kAcc);
+
+      // Because of the reversed operands the result is obtained transposed.
+      // Transpose it back.
+      if (swapOperands)
+        tiledRes = rewriter.create<vector::TransposeOp>(
+            loc, tiledRes, ArrayRef<int64_t>({1, 0}));
+
+      // With vecmat, only one row of tiled ACC can be inserted into the final
+      // result
+      if (dimM == 1)
+        tiledRes = rewriter.createOrFold<vector::ExtractOp>(loc, tiledRes, 0);
+
+      // Insert the tiled result back into the non tiled result of the
+      // contract op.
+      SmallVector<int64_t> strides(
+          cast<ShapedType>(tiledRes.getType()).getRank(), 1);
+      result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+          loc, tiledRes, result, accOffsets, strides);
+    }
+
+    rewriter.replaceOp(op, result);
+  }
+};
+
+class VectorContractRewriterI8MM : public VectorContractRewriter {
+public:
+  LogicalResult matchAndInit(vector::ContractionOp op,
+                             PatternRewriter &rewriter) {
+    if (failed(VectorContractRewriter::matchAndInit(op, rewriter)))
+      return failure();
+
+    // Unrolling patterns can handle any [2, 2, 8] shaped multiple of inputs for
+    // tiling.
+    if ((dimM != 1 && dimM % 2 != 0) || dimN % 2 != 0 || dimK % 8 != 0)
+      return rewriter.notifyMatchFailure(op, "Unsupported operand shapes");
+
+    // Check inputs are sign-/zero- extensions from iN (N <= 8) to i32. Get the
+    // values before the extension. All four signed/unsigned combinations for
+    // input operands are supported, but they are lowered to different
+    // operations. Determine which is the appropriate operation to lower to.
+    mmlaOp = MMLA::SignedInt;
+    auto maybeLhs = getExtOperand<arith::ExtSIOp>(op.getLhs());
+    if (!maybeLhs) {
+      mmlaOp = MMLA::UnsignedInt;
+      maybeLhs = getExtOperand<arith::ExtUIOp>(op.getLhs());
+    }
+    if (!maybeLhs)
+      return rewriter.notifyMatc...
[truncated]

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants