From 0c21e5461dc81fd4cef58047b53d3296b338c263 Mon Sep 17 00:00:00 2001 From: Vivian Zhang Date: Fri, 14 Feb 2025 05:47:24 -0800 Subject: [PATCH] Support matmul 4d inputs on pack-peel-4-level-tiling (#1098) This PR adds basic support for matmul with 4d inputs and output on pack-peel-4-level-tiling. In theory, the tiling strategy and pipeline should support different layouts of matmul 4d operations. However, the order of the input dims (both inner and outer dims) are crucial for the correct compilation and results. To ensure the correctness and for comparison purpose, this PR only adds a test that corresponds to the standard matmul, which means the order of the input dims for this operation corresponds to the L2 shapes of the matmul op after the first level packing, i.e., ```C += matmul4d(A,B) where A:MxKxM0xK0, B:NxKxK0xN0, C:NxMxM0xN0``` The test class and instance added in run.py is preliminary and for experimental purpose. Generalization of the test class will be addressed as follow-ups. Runtime comparison: On Phoenix runner `matmul_512_4096_512_bf16_f32 : 1141 us vs matmul4d_16_128_8_bf16_f32: 998 us` On Strix runner `matmul_512_4096_512_i8_i32 : 806 us vs matmul4d_16_128_8_i8_i32: 491 us` --- .../matmul4d_MxKxM0xK0_NxKxK0xN0.mlir | 17 +++ .../matmul_template/matmul_generator.py | 4 +- build_tools/ci/cpu_comparison/run.py | 111 ++++++++++++++++-- .../AMDAIEBufferizeToAllocation.cpp | 2 +- .../Transforms/AMDAIEFuseConsumerIntoLoop.cpp | 2 +- .../AMDAIELowerExecutableTarget.cpp | 9 +- .../Transforms/KernelDispatch.cpp | 98 ++++++++++------ .../iree-amd-aie/Transforms/Passes.cpp | 50 +++++--- .../AMD-AIE/iree-amd-aie/Transforms/Passes.h | 3 +- .../Transforms/Utils/AMDAIEUtils.cpp | 42 +++---- .../Transforms/Utils/AMDAIEUtils.h | 8 +- .../test/lowering_strategy_generic.mlir | 33 ++++++ 12 files changed, 282 insertions(+), 97 deletions(-) create mode 100644 build_tools/ci/cpu_comparison/matmul_template/matmul4d_MxKxM0xK0_NxKxK0xN0.mlir diff --git a/build_tools/ci/cpu_comparison/matmul_template/matmul4d_MxKxM0xK0_NxKxK0xN0.mlir b/build_tools/ci/cpu_comparison/matmul_template/matmul4d_MxKxM0xK0_NxKxK0xN0.mlir new file mode 100644 index 000000000..76ef7dd63 --- /dev/null +++ b/build_tools/ci/cpu_comparison/matmul_template/matmul4d_MxKxM0xK0_NxKxK0xN0.mlir @@ -0,0 +1,17 @@ +// input ${M}x${K}x32x64x${TYPE1} +// input ${N}x${K}x64x32x${TYPE1} + +func.func @matmul4d(%arg0: tensor<${M}x${K}x32x64x${TYPE1}>, %arg1: tensor<${N}x${K}x64x32x${TYPE1}>) -> tensor<${N}x${M}x32x32x${TYPE2}> { + %cst = arith.constant ${ZERO} : ${TYPE2} + %0 = tensor.empty() : tensor<${N}x${M}x32x32x${TYPE2}> + %1 = linalg.fill ins(%cst : ${TYPE2}) outs(%0 : tensor<${N}x${M}x32x32x${TYPE2}>) -> tensor<${N}x${M}x32x32x${TYPE2}> + %2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<${M}x${K}x32x64x${TYPE1}>, tensor<${N}x${K}x64x32x${TYPE1}>) outs(%1 : tensor<${N}x${M}x32x32x${TYPE2}>) { + ^bb0(%in: ${TYPE1}, %in_1: ${TYPE1}, %out: ${TYPE2}): + %12 = ${EXT} %in : ${TYPE1} to ${TYPE2} + %13 = ${EXT} %in_1 : ${TYPE1} to ${TYPE2} + %14 = ${MUL} %12, %13 : ${TYPE2} + %15 = ${ADD} %out, %14 : ${TYPE2} + linalg.yield %15 : ${TYPE2} + } -> tensor<${N}x${M}x32x32x${TYPE2}> + return %2 : tensor<${N}x${M}x32x32x${TYPE2}> +} diff --git a/build_tools/ci/cpu_comparison/matmul_template/matmul_generator.py b/build_tools/ci/cpu_comparison/matmul_template/matmul_generator.py index adc42cfb9..7fdd36b6d 100644 --- a/build_tools/ci/cpu_comparison/matmul_template/matmul_generator.py +++ b/build_tools/ci/cpu_comparison/matmul_template/matmul_generator.py @@ -1,6 +1,4 @@ -import sys import re -import os def get_higher_order_element_type(element_type): @@ -29,6 +27,8 @@ def generate_matmul_test(output_fn, input_fn, m, n, k, lhs_rhs_type, acc_type, b acc_is_int = acc_type[0] == "i" replace["ZERO"] = 0 if acc_is_int else 0.0 replace["ADD"] = "arith.addi" if acc_is_int else "arith.addf" + replace["MUL"] = "arith.muli" if acc_is_int else "arith.mulf" + replace["EXT"] = "arith.extsi" if acc_is_int else "arith.extf" key_map = map(lambda s: "${" + s + "}", replace.keys()) key_map_escaped = map(re.escape, key_map) diff --git a/build_tools/ci/cpu_comparison/run.py b/build_tools/ci/cpu_comparison/run.py index 3b8c585ef..adc310cbe 100755 --- a/build_tools/ci/cpu_comparison/run.py +++ b/build_tools/ci/cpu_comparison/run.py @@ -19,7 +19,6 @@ from input_generator import ( generate_inputs, verify_determinism, - load_input, get_output_type, np_from_binfile, ) @@ -486,6 +485,59 @@ def _execute(self, config): return self.vs_cpu(config) +class Matmul4d(BaseMatmul): + """ + A test of linalg.generic with 4d inputs and output implementing form: + C += matmul4d(A,B) where A:MxKxM0xK0, B:NxKxK0xN0, C:NxMxM0xN0 + + Note that the outer dims for this operation are transposed to make sure + successful compilation through LogicalObjectFifo pipeline. + For comparison purpose, the input values of inner dims M0/N0/K0 are + fixed as 32/32/64 currently. + TODO(vivian): Generalize the class and the template. + """ + + def __init__( + self, + M, + N, + K, + input_type, + acc_type, + additional_labels=None, + n_kernel_runs=1, + test_params=None, + ): + super().__init__( + name=f"matmul4d_{M}_{N}_{K}_{input_type}_{acc_type}", + test_params=test_params, + M=M, + N=N, + K=K, + input_type=input_type, + acc_type=acc_type, + function_name="matmul4d", + n_kernel_runs=n_kernel_runs, + ) + self.labels.append("Matmul4d") + if additional_labels: + self.labels += additional_labels + if self.run_benchmark: + self.aie_compilation_flags += [ + "--iree-amdaie-enable-infinite-loop-around-core-block=true" + ] + self.labels.append("Matmul4dBenchmark") + + def _execute(self, config): + matmul_template_dir = config.file_dir / "matmul_template" + template_name = matmul_template_dir / "matmul4d_MxKxM0xK0_NxKxK0xN0.mlir" + self.generate(config, template_name) + if self.run_benchmark: + return self.benchmark(config) + + return self.vs_cpu(config) + + class MatmulThinBias(BaseMatmul): """ A test of the form matmul(A,B) + C where A:MxK, B:KxN, C:N @@ -502,9 +554,11 @@ def __init__( ): super().__init__( name=f"matmul_thin_bias_{M}_{N}_{K}_{input_type}_{acc_type}", - test_params=test_params - if test_params is not None - else TestParams(lower_to_aie_pipeline="air"), + test_params=( + test_params + if test_params is not None + else TestParams(lower_to_aie_pipeline="air") + ), M=M, N=N, K=K, @@ -543,9 +597,11 @@ def __init__( ): super().__init__( name=f"matmul_full_bias_{M}_{N}_{K}_{input_type}_{acc_type}", - test_params=test_params - if test_params is not None - else TestParams(lower_to_aie_pipeline="air"), + test_params=( + test_params + if test_params is not None + else TestParams(lower_to_aie_pipeline="air") + ), M=M, N=N, K=K, @@ -565,8 +621,7 @@ def _execute(self, config): "--iree-amdaie-num-cols=2", ] ) - self.vs_cpu(config) - return True + return self.vs_cpu(config) class BatchMatmul(BaseMatmul): @@ -2194,6 +2249,21 @@ def __init__(self): "transpose_b": False, "tile_pipeline": "pack-peel-4-level-tiling", }, + # matmul4d test where the input M/N/K are outer dim values. + # The total input values correspond to a standard matmul + # from the above test are M:512, N:4096, K:512. + { + "M": 16, + "N": 128, + "K": 8, + "use_ukernel": True, + "peano_opt_level": 3, + "outline": "balanced", + "transpose_a": False, + "transpose_b": False, + "matmul4d": True, + "tile_pipeline": "pack-peel-4-level-tiling", + }, # Test where the compute is omitted, this should help triangulate # how much performance gain can be obtained with better matmul # on core vs data movement. @@ -2257,6 +2327,24 @@ def __init__(self): "tile_pipeline": "pack-peel-4-level-tiling", "run_on_target": "npu4", }, + # matmul4d test where the input M/N/K are outer dim values. + # The total input values correspond to a standard matmul + # from the above test are M:512, N:4096, K:512. + { + "M": 16, + "N": 128, + "K": 8, + "in_dtype": "i8", + "out_dtype": "i32", + "use_ukernel": True, + "peano_opt_level": 3, + "outline": "all", + "transpose_a": False, + "transpose_b": False, + "matmul4d": True, + "tile_pipeline": "pack-peel-4-level-tiling", + "run_on_target": "npu4", + }, { "M": 512, "N": 4096, @@ -2301,6 +2389,7 @@ def __init__(self): transpose_a = test["transpose_a"] transpose_b = test["transpose_b"] tile_pipeline = test["tile_pipeline"] + matmul4d = test["matmul4d"] if "matmul4d" in test else False run_on_target = ( test["run_on_target"] if "run_on_target" in test else "npu1_4col" ) @@ -2343,7 +2432,9 @@ def __init__(self): else: name_suffix += "_outline" - if (transpose_a, transpose_b) == (False, False): + if matmul4d: + TestClass = Matmul4d + elif (transpose_a, transpose_b) == (False, False): TestClass = Matmul elif (transpose_a, transpose_b) == (True, False): TestClass = MatmulTransposeA diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEBufferizeToAllocation.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEBufferizeToAllocation.cpp index 40b726691..969fa9f34 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEBufferizeToAllocation.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEBufferizeToAllocation.cpp @@ -172,7 +172,7 @@ void AMDAIEBufferizeToAllocationPass::runOnOperation() { !linalg::isaConvolutionOpInterface(op)) { return WalkResult::advance(); } - if (isa(op)) { + if (isa(op)) { return WalkResult::advance(); } // Use flag `bufferizeElementwise` to indicate whether the target for diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseConsumerIntoLoop.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseConsumerIntoLoop.cpp index 927c9bd5d..f119b31aa 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseConsumerIntoLoop.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIEFuseConsumerIntoLoop.cpp @@ -48,7 +48,7 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() { // Check if there is matmul-elementwise fusion opportunity. If so, overwrite // the `fuseDepth` to be 2. funcOp->walk([&](linalg::LinalgOp op) { - if (isMatmulProducerOfElementwise(op)) { + if (isElementwiseWithMatmulProducer(op)) { fuseDepth = 2; return WalkResult::interrupt(); } diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerExecutableTarget.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerExecutableTarget.cpp index 8fbfd6434..eb981a073 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerExecutableTarget.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerExecutableTarget.cpp @@ -57,6 +57,13 @@ class AMDAIELowerExecutableTargetPass }; } // namespace +static Operation *getRootOp(FunctionOpInterface funcOp) { + SmallVector computeOps = getComputeOps(funcOp); + FailureOr rootOp = getRootOperation(computeOps); + assert(succeeded(rootOp) && "Pipeline requires a root operation"); + return rootOp.value(); +} + void AMDAIELowerExecutableTargetPass::runOnOperation() { auto funcOp = getOperation(); auto target = IREE::HAL::ExecutableTargetAttr::lookup(funcOp); @@ -82,7 +89,7 @@ void AMDAIELowerExecutableTargetPass::runOnOperation() { TilePassPipeline::PackPeel4LevelTilingPipeline) { addPackPeel4LevelTilingBasedPassPipeline( executableLoweringPipeline, pathToUkernels, - TilePassPipeline::PackPeel4LevelTilingPipeline); + TilePassPipeline::PackPeel4LevelTilingPipeline, getRootOp(funcOp)); } else if (useTilePipeline == TilePassPipeline::PadPackPipeline) { addPadPackBasedPassPipeline(executableLoweringPipeline, pathToUkernels, enableVectorizationPasses, diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp index fb54f8693..ae1ba0e09 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/KernelDispatch.cpp @@ -468,47 +468,62 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline( MLIRContext *context = entryPointFn.getContext(); unsigned numLoops = linalgOp.getNumLoops(); - // Pack level => 1. - SmallVector packedSizesL0(numLoops, 0); - packedSizesL0[mDims.back()] = packPeelTiling.m0Pack; - packedSizesL0[nDims.back()] = packPeelTiling.n0Pack; - packedSizesL0[kDims.back()] = packPeelTiling.k0Pack; - - // For matmul, transpose B matrix from [K N n k] to [N K k n] - // For matmul_transpose_b, we don't have to transpose the B matrix, - // since it is already [N K n k] - SmallVector transposePackIndices = {0, 1, 2}; - // There is no corresponding unpack for the specified pack operation - // 0 is used when unpack is empty - SmallVector unpackEmpty = {false, false, true}; + bool isBatchMatmul = isa(linalgOp); SmallVector innerPermA = setInnerPermA(isMatmulTransposeA(linalgOp)); SmallVector innerPermB = setInnerPermB(isMatmulTransposeB(linalgOp)); - SmallVector> innerPerm = { - innerPermA, innerPermB, {0, 1}}; - bool isBatchMatmul = isa(linalgOp); SmallVector outerPermA = setOuterPermA(isMatmulTransposeA(linalgOp), isBatchMatmul); SmallVector outerPermB = setOuterPermB(isMatmulTransposeB(linalgOp), isBatchMatmul); - SmallVector> outerPerm = {outerPermA, outerPermB}; - // Add outer permutation for unpack. NOTE: This currently fails for some - // tests in the AIR pipeline. - if (isa(linalgOp)) { - outerPerm.push_back({0, 2, 1}); - } else { - outerPerm.push_back({1, 0}); - } - auto packingConfigLevel0Attr = getPackingConfigPackingLevelAttr( - context, packedSizesL0, transposePackIndices, unpackEmpty, innerPerm, - outerPerm); + SmallVector transposePackIndices; + SmallVector unpackEmpty; + SmallVector> innerPerm; + SmallVector> outerPerm; + SmallVector packingConfigLevelsVal; + + // Pack level => 1. + // For 2D matmul-like ops, the first level is to pack operands from 2D to 4D. + // If the input is a 4D matmul-like op, this level of packing is not needed. + bool is2DMatmulLike = is2DMatmulLikeOp(linalgOp) || isBatchMatmul; + if (is2DMatmulLike) { + SmallVector packedSizesL0(numLoops, 0); + packedSizesL0[mDims.back()] = packPeelTiling.m0Pack; + packedSizesL0[nDims.back()] = packPeelTiling.n0Pack; + packedSizesL0[kDims.back()] = packPeelTiling.k0Pack; + + transposePackIndices = {0, 1, 2}; + // There is no corresponding unpack for the specified pack operation + // 0 is used when unpack is empty + unpackEmpty = {false, false, true}; + innerPerm = {innerPermA, innerPermB, {0, 1}}; + outerPerm = {outerPermA, outerPermB}; + // Add outer permutation for unpack. NOTE: This currently fails for some + // tests in the AIR pipeline. + if (isBatchMatmul) { + outerPerm.push_back({0, 2, 1}); + } else { + outerPerm.push_back({1, 0}); + } + + auto packingConfigLevel0Attr = getPackingConfigPackingLevelAttr( + context, packedSizesL0, transposePackIndices, unpackEmpty, innerPerm, + outerPerm); + packingConfigLevelsVal.push_back(packingConfigLevel0Attr); + } // Pack level => 2. - // The number of loops have increased by 3 due to the first level pack. - SmallVector packedSizesL1(numLoops + 3, 0); - packedSizesL1[mDims.back() + 3] = packPeelTiling.m1Pack; - packedSizesL1[nDims.back() + 3] = packPeelTiling.n1Pack; - packedSizesL1[kDims.back() + 3] = packPeelTiling.k1Pack; + // If the first level pack exists (for 2D matmul-like ops), the number of + // packed dimensions should increase by 3, otherwise keep the original + // number of loops. + unsigned numPackedDims = is2DMatmulLike ? numLoops + 3 : numLoops; + unsigned mIdx = is2DMatmulLike ? mDims.back() + 3 : mDims.back(); + unsigned nIdx = is2DMatmulLike ? nDims.back() + 3 : nDims.back(); + unsigned kIdx = is2DMatmulLike ? kDims.back() + 3 : kDims.back(); + SmallVector packedSizesL1(numPackedDims, 0); + packedSizesL1[mIdx] = packPeelTiling.m1Pack; + packedSizesL1[nIdx] = packPeelTiling.n1Pack; + packedSizesL1[kIdx] = packPeelTiling.k1Pack; // Transpose A matrix from [M K m k m0 k0] to [M K k m m0 k0] // Transpose C matrix from [M N m n m0 n0] to [M N n m m0 n0] @@ -519,7 +534,7 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline( // Only the third pack operation has a corresponding unpack operation unpackEmpty = {false, false, true}; innerPerm = {innerPermA, innerPermB, {0, 1}}; - if (isa(linalgOp)) { + if (isBatchMatmul) { outerPerm = {{0, 1, 2, 4, 3}, {0, 1, 2, 4, 3}, {0, 1, 2, 4, 3}}; } else { outerPerm = {{0, 1, 3, 2}, {0, 1, 3, 2}, {0, 1, 3, 2}}; @@ -527,9 +542,8 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline( auto packingConfigLevel1Attr = getPackingConfigPackingLevelAttr( context, packedSizesL1, transposePackIndices, unpackEmpty, innerPerm, outerPerm); + packingConfigLevelsVal.push_back(packingConfigLevel1Attr); - SmallVector packingConfigLevelsVal = { - packingConfigLevel0Attr, packingConfigLevel1Attr}; auto packingConfigLevels = PackingConfigPackingLevelsAttr::get(context, packingConfigLevelsVal); auto config = PackingConfigAttr::get(context, packingConfigLevels); @@ -550,14 +564,22 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline( bool fitsInL2 = (l2SizeA + l2SizeB + l2SizeInit) < (deviceModel.getMemTileSizeInBytes() * numCols); int64_t scaleL0 = !isBatchMatmul && fitsInL2 ? 2 : 1; + int64_t m0Tile = packPeelTiling.M0 * scaleL0; + int64_t n0Tile = packPeelTiling.N0 * scaleL0; SmallVector tileSizeLevel0(numLoops, 0); - if (isa(linalgOp)) { + if (isBatchMatmul) { assert(!batchDims.empty() && "expected batch dims not empty"); tileSizeLevel0[batchDims[0]] = 1; } - tileSizeLevel0[mDims[0]] = packPeelTiling.M0 * scaleL0; - tileSizeLevel0[nDims[0]] = packPeelTiling.N0 * scaleL0; + // For 4D matmul-like ops, only tile the outer dims. + // outer_tile_size = total_tile_size / inner_dim_size + if (is4DMatmulLikeOp(linalgOp)) { + m0Tile /= maybeInputDimsAndSizes.value().mSizes.back(); + n0Tile /= maybeInputDimsAndSizes.value().nSizes.back(); + } + tileSizeLevel0[mDims[0]] = m0Tile; + tileSizeLevel0[nDims[0]] = n0Tile; SmallVector tileSizeLevel1(numLoops, 0); tileSizeLevel1[mDims[0]] = numRows; diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp index c79c87fc7..f76e98264 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp @@ -21,6 +21,7 @@ #include "air/Transform/AffineLoopOptPass.h" #include "iree-amd-aie/IR/AMDAIEAttrs.h" #include "iree-amd-aie/Transforms/Passes.h" +#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h" #include "iree-dialects/Dialect/LinalgTransform/Passes.h" #include "iree/compiler/Codegen/Common/Passes.h" #include "iree/compiler/Utils/ToolUtils.h" @@ -309,9 +310,13 @@ void addPackPeelBasedPassPipeline(OpPassManager &funcPassManager, funcPassManager.addPass(createHoistStaticallyBoundAllocationsPass()); } -void addPackPeel4LevelTilingBasedPassPipeline( - OpPassManager &funcPassManager, const std::string &pathToUkernels, - TilePassPipeline useTilePipeline) { +void addPackPeel4LevelTilingBasedPassPipeline(OpPassManager &funcPassManager, + const std::string &pathToUkernels, + TilePassPipeline useTilePipeline, + Operation *rootOp) { + // Check if the root op is a 4D matmul-like operation. + bool is4DMatmulOp = is4DMatmulLikeOp(cast(rootOp)); + // First level tiling using scf.forall { AMDAIETileAndFuseOptions tileFuseOptions; @@ -323,18 +328,32 @@ void addPackPeel4LevelTilingBasedPassPipeline( funcPassManager.addPass(createCanonicalizerPass()); funcPassManager.addPass(createCSEPass()); - // First level packing - { - AMDAIEPackAndTransposeOptions packOptions; - packOptions.packLevel = 0; - funcPassManager.addPass(createAMDAIEPackAndTransposePass(packOptions)); + // First level pack or pad operation for data movement. + // For 2D matmul-like ops, pack operation is used to expand operands from 2D + // to 4D. For 4D matmul-like ops, pad operation is used to keep the original + // dimensions. + if (is4DMatmulOp) { + // First level pad + { + AMDAIEPadOptions padOptions; + padOptions.paddingLevel = 0; + funcPassManager.addPass(createAMDAIEPadPass(padOptions)); + } + funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createCSEPass()); + } else { + // First level packing + { + AMDAIEPackAndTransposeOptions packOptions; + packOptions.packLevel = 0; + funcPassManager.addPass(createAMDAIEPackAndTransposePass(packOptions)); + } + // Propagate pack ops for the elementwise op + funcPassManager.addPass(createAMDAIEPropagateDataLayoutPass()); + funcPassManager.addPass(createCanonicalizerPass()); + funcPassManager.addPass(createCSEPass()); } - // Propagate pack ops for the elementwise op - funcPassManager.addPass(createAMDAIEPropagateDataLayoutPass()); - funcPassManager.addPass(createCanonicalizerPass()); - funcPassManager.addPass(createCSEPass()); - // Promote the matmul output to shared memory { AMDAIEBufferizeToAllocationOptions bufferizeOptions; @@ -354,10 +373,11 @@ void addPackPeel4LevelTilingBasedPassPipeline( createAMDAIEBufferizeToAllocationPass(bufferizeOptions)); } - // Second level packing + // If the input is 4D matmul-like op, this is the first level of packing. + // Otherwise for 2D matmul-like op, it is the second level. { AMDAIEPackAndTransposeOptions packOptions; - packOptions.packLevel = 1; + packOptions.packLevel = is4DMatmulOp ? 0 : 1; funcPassManager.addPass(createAMDAIEPackAndTransposePass(packOptions)); } diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h index dd4468f14..a01fe71a0 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.h @@ -55,7 +55,8 @@ void addPackPeelBasedPassPipeline(OpPassManager &passManager, /// 4 levels of tiling. void addPackPeel4LevelTilingBasedPassPipeline(OpPassManager &passManager, const std::string &pathToUkernels, - TilePassPipeline useTilePipeline); + TilePassPipeline useTilePipeline, + Operation *rootOp); /// Populates passes needed to lower the IR via a Pad-Pack based approach. void addPadPackBasedPassPipeline(OpPassManager &passManager, diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.cpp index 8ad6c0e4a..ead1c79b9 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.cpp @@ -148,14 +148,11 @@ static bool bodyMatcherForMatmulLikeOps(Value yieldVal, Block *body) { } /// Utility to check if the input generic op is a 2D matmul-like op. -static bool is2DMatmulLikeOp(linalg::LinalgOp linalgOp) { +bool is2DMatmulLikeOp(linalg::LinalgOp linalgOp) { // Check iterator types. - SmallVector matmulIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::reduction}; - SmallVector opIteratorTypes = - linalgOp.getIteratorTypesArray(); - if (matmulIteratorTypes != opIteratorTypes) return false; + unsigned numParallelLoops = linalgOp.getNumParallelLoops(); + unsigned numReductionLoops = linalgOp.getNumReductionLoops(); + if (numParallelLoops != 2 || numReductionLoops != 1) return false; // Check the number of inputs and results from indexing maps. ArrayAttr indexingMaps = linalgOp.getIndexingMaps(); @@ -174,15 +171,11 @@ static bool is2DMatmulLikeOp(linalg::LinalgOp linalgOp) { } /// Utility to check if the input generic op is a 4D matmul-like op. -static bool is4DMatmulLikeOp(linalg::LinalgOp linalgOp) { +bool is4DMatmulLikeOp(linalg::LinalgOp linalgOp) { // Check iterator types. - SmallVector matmulIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::reduction, utils::IteratorType::parallel, - utils::IteratorType::parallel, utils::IteratorType::reduction}; - SmallVector opIteratorTypes = - linalgOp.getIteratorTypesArray(); - if (matmulIteratorTypes != opIteratorTypes) return false; + unsigned numParallelLoops = linalgOp.getNumParallelLoops(); + unsigned numReductionLoops = linalgOp.getNumReductionLoops(); + if (numParallelLoops != 4 || numReductionLoops != 2) return false; // Check indexing maps. ArrayAttr indexingMaps = linalgOp.getIndexingMaps(); @@ -201,17 +194,11 @@ static bool is4DMatmulLikeOp(linalg::LinalgOp linalgOp) { } /// Utility to check if the input generic op is a 6D matmul-like op. -static bool is6DMatmulLikeOp(linalg::LinalgOp linalgOp) { +bool is6DMatmulLikeOp(linalg::LinalgOp linalgOp) { // Check iterator types. - SmallVector matmulIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::reduction, utils::IteratorType::parallel, - utils::IteratorType::parallel, utils::IteratorType::reduction, - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::reduction}; - SmallVector opIteratorTypes = - linalgOp.getIteratorTypesArray(); - if (matmulIteratorTypes != opIteratorTypes) return false; + unsigned numParallelLoops = linalgOp.getNumParallelLoops(); + unsigned numReductionLoops = linalgOp.getNumReductionLoops(); + if (numParallelLoops != 6 || numReductionLoops != 3) return false; // Check indexing maps. ArrayAttr indexingMaps = linalgOp.getIndexingMaps(); @@ -340,8 +327,9 @@ bool isMatmulInDefChain(Value operand) { /// Utility to identify if `linalgOp` is an elementwise operation with a /// matmul-like op upstream in its computation tree. -bool isMatmulProducerOfElementwise(linalg::LinalgOp linalgOp) { - if (!linalg::isElementwise(linalgOp) || isa(linalgOp)) { +bool isElementwiseWithMatmulProducer(linalg::LinalgOp linalgOp) { + if (!linalg::isElementwise(linalgOp) || + isa(linalgOp)) { return false; } // Check if any of the defining op is a matmul-like op. diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.h index f7dc69a42..663e89bf6 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.h +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.h @@ -67,13 +67,19 @@ bool isMatmulTransposeA(linalg::LinalgOp linalgOp); /// Utility to identify whether a linalg op is a matmul_transpose_b op. bool isMatmulTransposeB(linalg::LinalgOp linalgOp); +/// Utility to identify whether a linalg op is a 2D matmul-like op. +bool is2DMatmulLikeOp(linalg::LinalgOp linalgOp); + +/// Utility to identify whether a linalg op is a 4D matmul-like op. +bool is4DMatmulLikeOp(linalg::LinalgOp linalgOp); + /// Utility to identify if the input operand has matmul-like op in its /// def-chain. bool isMatmulInDefChain(Value operand); /// Utility to identify if `linalgOp` is an elementwise operation with a /// matmul-like op upstream in its computation tree. -bool isMatmulProducerOfElementwise(linalg::LinalgOp linalgOp); +bool isElementwiseWithMatmulProducer(linalg::LinalgOp linalgOp); /// Utility to convert a `uint32_t` value into a hex string. std::string utohexstr(uint32_t value, size_t width, bool header = true, diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_generic.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_generic.mlir index f8bbb02f5..b33071f1b 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_generic.mlir +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lowering_strategy_generic.mlir @@ -1,4 +1,5 @@ // RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-amdaie-lowering-strategy)' %s | FileCheck %s +// RUN: iree-opt --split-input-file --pass-pipeline='builtin.module(iree-amdaie-lowering-strategy{use-tile-pipeline=pack-peel-4-level-tiling})' %s | FileCheck %s --check-prefix=PACK-PEEL-4-LEVEL // Test generic version of matmul. @@ -115,3 +116,35 @@ module { return } } + +// ----- + +// Test generic version of matmul with 4d inputs and output. + +// PACK-PEEL-4-LEVEL{LITERAL}: #config = #iree_codegen.lowering_config +// PACK-PEEL-4-LEVEL{LITERAL}: #amdaie.packing_config +module { + func.func @mmt4d_dispatch_0_matmul_like_16x128x32x32x8x64_bf16xbf16xf32() { + %cst = arith.constant 0.000000e+00 : f32 + %c0 = arith.constant 0 : index + %0 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(0) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %1 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(1) alignment(64) offset(%c0) flags("ReadOnly|Indirect") : !flow.dispatch.tensor> + %2 = hal.interface.binding.subspan layout(, #hal.pipeline.binding, #hal.pipeline.binding], flags = Indirect>) binding(2) alignment(64) offset(%c0) flags(Indirect) : !flow.dispatch.tensor> + %3 = flow.dispatch.tensor.load %0, offsets = [0, 0, 0, 0], sizes = [16, 8, 32, 64], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<16x8x32x64xbf16> + %4 = flow.dispatch.tensor.load %1, offsets = [0, 0, 0, 0], sizes = [128, 8, 64, 32], strides = [1, 1, 1, 1] : !flow.dispatch.tensor> -> tensor<128x8x64x32xbf16> + %5 = tensor.empty() : tensor<128x16x32x32xf32> + %6 = linalg.fill ins(%cst : f32) outs(%5 : tensor<128x16x32x32xf32>) -> tensor<128x16x32x32xf32> + // PACK-PEEL-4-LEVEL: linalg.generic + // PACK-PEEL-4-LEVEL-SAME: attrs = {lowering_config = #config, packing_config = #packingConfig} + %7 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d4, d2, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d4, d5, d3)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} ins(%3, %4 : tensor<16x8x32x64xbf16>, tensor<128x8x64x32xbf16>) outs(%6 : tensor<128x16x32x32xf32>) { + ^bb0(%in: bf16, %in_0: bf16, %out: f32): + %8 = arith.extf %in : bf16 to f32 + %9 = arith.extf %in_0 : bf16 to f32 + %10 = arith.mulf %8, %9 : f32 + %11 = arith.addf %out, %10 : f32 + linalg.yield %11 : f32 + } -> tensor<128x16x32x32xf32> + flow.dispatch.tensor.store %7, %2, offsets = [0, 0, 0, 0], sizes = [128, 16, 32, 32], strides = [1, 1, 1, 1] : tensor<128x16x32x32xf32> -> !flow.dispatch.tensor> + return + } +}