Support matmul 4d inputs on pack-peel-4-level-tiling (nod-ai#1098)
This PR adds basic support for matmul with 4d inputs and output on the
pack-peel-4-level-tiling pipeline.

In theory, the tiling strategy and pipeline should support different
layouts of 4d matmul operations. However, the order of the input dims
(both inner and outer) is crucial for correct compilation and results.

To ensure correctness and for comparison purposes, this PR only adds a
test corresponding to the standard matmul, i.e., the order of the input
dims for this operation matches the L2 shapes of the matmul op after the
first-level packing:

```C += matmul4d(A,B) where A:MxKxM0xK0, B:NxKxK0xN0, C:NxMxM0xN0```
  
The test class and instance added in run.py are preliminary and for experimental purposes. Generalization of the test class will be addressed in follow-ups.

Runtime comparison:
- Phoenix runner: `matmul_512_4096_512_bf16_f32: 1141 us` vs `matmul4d_16_128_8_bf16_f32: 998 us`
- Strix runner: `matmul_512_4096_512_i8_i32: 806 us` vs `matmul4d_16_128_8_i8_i32: 491 us`
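
The 4d test shapes follow directly from the fixed inner dims (M0/N0/K0 = 32/32/64). A minimal sketch of that shape bookkeeping, assuming only the conventions described above:

```python
# Map a standard matmul problem size to the matmul4d outer dims.
# Inner dims are fixed at M0=32, N0=32, K0=64 in the current template.
M0, N0, K0 = 32, 32, 64

def matmul4d_outer_dims(M, N, K):
    assert M % M0 == 0 and N % N0 == 0 and K % K0 == 0
    return M // M0, N // N0, K // K0

# matmul_512_4096_512 corresponds to matmul4d_16_128_8:
print(matmul4d_outer_dims(512, 4096, 512))  # (16, 128, 8)
```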
yzhang93 authored Feb 14, 2025
1 parent 683406d commit 0c21e54
Showing 12 changed files with 282 additions and 97 deletions.
@@ -0,0 +1,17 @@
// input ${M}x${K}x32x64x${TYPE1}
// input ${N}x${K}x64x32x${TYPE1}

func.func @matmul4d(%arg0: tensor<${M}x${K}x32x64x${TYPE1}>, %arg1: tensor<${N}x${K}x64x32x${TYPE1}>) -> tensor<${N}x${M}x32x32x${TYPE2}> {
%cst = arith.constant ${ZERO} : ${TYPE2}
%0 = tensor.empty() : tensor<${N}x${M}x32x32x${TYPE2}>
%1 = linalg.fill ins(%cst : ${TYPE2}) outs(%0 : tensor<${N}x${M}x32x32x${TYPE2}>) -> tensor<${N}x${M}x32x32x${TYPE2}>
%2 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d2, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<${M}x${K}x32x64x${TYPE1}>, tensor<${N}x${K}x64x32x${TYPE1}>) outs(%1 : tensor<${N}x${M}x32x32x${TYPE2}>) {
^bb0(%in: ${TYPE1}, %in_1: ${TYPE1}, %out: ${TYPE2}):
%12 = ${EXT} %in : ${TYPE1} to ${TYPE2}
%13 = ${EXT} %in_1 : ${TYPE1} to ${TYPE2}
%14 = ${MUL} %12, %13 : ${TYPE2}
%15 = ${ADD} %out, %14 : ${TYPE2}
linalg.yield %15 : ${TYPE2}
} -> tensor<${N}x${M}x32x32x${TYPE2}>
return %2 : tensor<${N}x${M}x32x32x${TYPE2}>
}
@@ -1,6 +1,4 @@
import sys
import re
import os


def get_higher_order_element_type(element_type):
@@ -29,6 +27,8 @@ def generate_matmul_test(output_fn, input_fn, m, n, k, lhs_rhs_type, acc_type, b
acc_is_int = acc_type[0] == "i"
replace["ZERO"] = 0 if acc_is_int else 0.0
replace["ADD"] = "arith.addi" if acc_is_int else "arith.addf"
replace["MUL"] = "arith.muli" if acc_is_int else "arith.mulf"
replace["EXT"] = "arith.extsi" if acc_is_int else "arith.extf"

key_map = map(lambda s: "${" + s + "}", replace.keys())
key_map_escaped = map(re.escape, key_map)
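
To illustrate the `${KEY}` substitution that `generate_matmul_test` performs with these entries, here is a minimal standalone sketch; the `template` string and concrete values are chosen for the bf16 matmul4d test and are illustrative, not the generator's actual code:

```python
import re

# Hypothetical substitution values for the bf16 matmul4d test (M=16, N=128, K=8).
replace = {
    "M": 16, "N": 128, "K": 8,
    "TYPE1": "bf16", "TYPE2": "f32",
    "ZERO": 0.0,
    "ADD": "arith.addf", "MUL": "arith.mulf", "EXT": "arith.extf",
}

template = "tensor<${M}x${K}x32x64x${TYPE1}> -> tensor<${N}x${M}x32x32x${TYPE2}>"

# Same idea as the generator: escape each ${KEY} and substitute its value.
for key, value in replace.items():
    template = re.sub(re.escape("${" + key + "}"), str(value), template)

print(template)  # tensor<16x8x32x64xbf16> -> tensor<128x16x32x32xf32>
```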
111 changes: 101 additions & 10 deletions build_tools/ci/cpu_comparison/run.py
@@ -19,7 +19,6 @@
from input_generator import (
generate_inputs,
verify_determinism,
load_input,
get_output_type,
np_from_binfile,
)
@@ -486,6 +485,59 @@ def _execute(self, config):
return self.vs_cpu(config)


class Matmul4d(BaseMatmul):
"""
A test of linalg.generic with 4d inputs and output implementing the form:
C += matmul4d(A,B) where A:MxKxM0xK0, B:NxKxK0xN0, C:NxMxM0xN0
Note that the outer dims for this operation are transposed to ensure
successful compilation through the LogicalObjectFifo pipeline.
For comparison purposes, the inner dim sizes M0/N0/K0 are currently
fixed at 32/32/64.
TODO(vivian): Generalize the class and the template.
"""

def __init__(
self,
M,
N,
K,
input_type,
acc_type,
additional_labels=None,
n_kernel_runs=1,
test_params=None,
):
super().__init__(
name=f"matmul4d_{M}_{N}_{K}_{input_type}_{acc_type}",
test_params=test_params,
M=M,
N=N,
K=K,
input_type=input_type,
acc_type=acc_type,
function_name="matmul4d",
n_kernel_runs=n_kernel_runs,
)
self.labels.append("Matmul4d")
if additional_labels:
self.labels += additional_labels
if self.run_benchmark:
self.aie_compilation_flags += [
"--iree-amdaie-enable-infinite-loop-around-core-block=true"
]
self.labels.append("Matmul4dBenchmark")

def _execute(self, config):
matmul_template_dir = config.file_dir / "matmul_template"
template_name = matmul_template_dir / "matmul4d_MxKxM0xK0_NxKxK0xN0.mlir"
self.generate(config, template_name)
if self.run_benchmark:
return self.benchmark(config)

return self.vs_cpu(config)


class MatmulThinBias(BaseMatmul):
"""
A test of the form matmul(A,B) + C where A:MxK, B:KxN, C:N
@@ -502,9 +554,11 @@ def __init__(
):
super().__init__(
name=f"matmul_thin_bias_{M}_{N}_{K}_{input_type}_{acc_type}",
test_params=test_params
if test_params is not None
else TestParams(lower_to_aie_pipeline="air"),
test_params=(
test_params
if test_params is not None
else TestParams(lower_to_aie_pipeline="air")
),
M=M,
N=N,
K=K,
@@ -543,9 +597,11 @@ def __init__(
):
super().__init__(
name=f"matmul_full_bias_{M}_{N}_{K}_{input_type}_{acc_type}",
test_params=test_params
if test_params is not None
else TestParams(lower_to_aie_pipeline="air"),
test_params=(
test_params
if test_params is not None
else TestParams(lower_to_aie_pipeline="air")
),
M=M,
N=N,
K=K,
@@ -565,8 +621,7 @@ def _execute(self, config):
"--iree-amdaie-num-cols=2",
]
)
self.vs_cpu(config)
return True
return self.vs_cpu(config)


class BatchMatmul(BaseMatmul):
@@ -2194,6 +2249,21 @@ def __init__(self):
"transpose_b": False,
"tile_pipeline": "pack-peel-4-level-tiling",
},
# matmul4d test where the input M/N/K are outer dim values.
# The total input sizes correspond to the standard matmul test
# above with M:512, N:4096, K:512.
{
"M": 16,
"N": 128,
"K": 8,
"use_ukernel": True,
"peano_opt_level": 3,
"outline": "balanced",
"transpose_a": False,
"transpose_b": False,
"matmul4d": True,
"tile_pipeline": "pack-peel-4-level-tiling",
},
# Test where the compute is omitted, this should help triangulate
# how much performance gain can be obtained with better matmul
# on core vs data movement.
@@ -2257,6 +2327,24 @@ def __init__(self):
"tile_pipeline": "pack-peel-4-level-tiling",
"run_on_target": "npu4",
},
# matmul4d test where the input M/N/K are outer dim values.
# The total input sizes correspond to the standard matmul test
# above with M:512, N:4096, K:512.
{
"M": 16,
"N": 128,
"K": 8,
"in_dtype": "i8",
"out_dtype": "i32",
"use_ukernel": True,
"peano_opt_level": 3,
"outline": "all",
"transpose_a": False,
"transpose_b": False,
"matmul4d": True,
"tile_pipeline": "pack-peel-4-level-tiling",
"run_on_target": "npu4",
},
{
"M": 512,
"N": 4096,
@@ -2301,6 +2389,7 @@ def __init__(self):
transpose_a = test["transpose_a"]
transpose_b = test["transpose_b"]
tile_pipeline = test["tile_pipeline"]
matmul4d = test["matmul4d"] if "matmul4d" in test else False
run_on_target = (
test["run_on_target"] if "run_on_target" in test else "npu1_4col"
)
@@ -2343,7 +2432,9 @@ def __init__(self):
else:
name_suffix += "_outline"

if (transpose_a, transpose_b) == (False, False):
if matmul4d:
TestClass = Matmul4d
elif (transpose_a, transpose_b) == (False, False):
TestClass = Matmul
elif (transpose_a, transpose_b) == (True, False):
TestClass = MatmulTransposeA
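
For reference, the matmul4d test dict above expands to roughly the following constructor call (a sketch; the benchmark flags, ukernel options, and labels are plumbed through separately by the test harness):

```python
# Rough equivalent of the Phoenix bf16 matmul4d test instance; the
# generated test name is matmul4d_16_128_8_bf16_f32.
test = Matmul4d(
    M=16,    # outer M; full size 16 * 32 = 512
    N=128,   # outer N; full size 128 * 32 = 4096
    K=8,     # outer K; full size 8 * 64 = 512
    input_type="bf16",
    acc_type="f32",
)
```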
@@ -172,7 +172,7 @@ void AMDAIEBufferizeToAllocationPass::runOnOperation() {
!linalg::isaConvolutionOpInterface(op)) {
return WalkResult::advance();
}
if (isa<linalg::FillOp>(op)) {
if (isa<linalg::FillOp, linalg::CopyOp>(op)) {
return WalkResult::advance();
}
// Use flag `bufferizeElementwise` to indicate whether the target for
@@ -48,7 +48,7 @@ void AMDAIEFuseConsumerIntoLoopPass::runOnOperation() {
// Check if there is matmul-elementwise fusion opportunity. If so, overwrite
// the `fuseDepth` to be 2.
funcOp->walk<WalkOrder::PostOrder, ReverseIterator>([&](linalg::LinalgOp op) {
if (isMatmulProducerOfElementwise(op)) {
if (isElementwiseWithMatmulProducer(op)) {
fuseDepth = 2;
return WalkResult::interrupt();
}
@@ -57,6 +57,13 @@ class AMDAIELowerExecutableTargetPass
};
} // namespace

static Operation *getRootOp(FunctionOpInterface funcOp) {
SmallVector<Operation *> computeOps = getComputeOps(funcOp);
FailureOr<Operation *> rootOp = getRootOperation(computeOps);
assert(succeeded(rootOp) && "Pipeline requires a root operation");
return rootOp.value();
}

void AMDAIELowerExecutableTargetPass::runOnOperation() {
auto funcOp = getOperation();
auto target = IREE::HAL::ExecutableTargetAttr::lookup(funcOp);
@@ -82,7 +89,7 @@ void AMDAIELowerExecutableTargetPass::runOnOperation() {
TilePassPipeline::PackPeel4LevelTilingPipeline) {
addPackPeel4LevelTilingBasedPassPipeline(
executableLoweringPipeline, pathToUkernels,
TilePassPipeline::PackPeel4LevelTilingPipeline);
TilePassPipeline::PackPeel4LevelTilingPipeline, getRootOp(funcOp));
} else if (useTilePipeline == TilePassPipeline::PadPackPipeline) {
addPadPackBasedPassPipeline(executableLoweringPipeline, pathToUkernels,
enableVectorizationPasses,
@@ -468,47 +468,62 @@ static LogicalResult setRootConfigForPackPeel4LevelTilingPipeline(
MLIRContext *context = entryPointFn.getContext();
unsigned numLoops = linalgOp.getNumLoops();

// Pack level => 1.
SmallVector<int64_t> packedSizesL0(numLoops, 0);
packedSizesL0[mDims.back()] = packPeelTiling.m0Pack;
packedSizesL0[nDims.back()] = packPeelTiling.n0Pack;
packedSizesL0[kDims.back()] = packPeelTiling.k0Pack;

// For matmul, transpose B matrix from [K N n k] to [N K k n]
// For matmul_transpose_b, we don't have to transpose the B matrix,
// since it is already [N K n k]
SmallVector<int64_t> transposePackIndices = {0, 1, 2};
// There is no corresponding unpack for the specified pack operation
// 0 is used when unpack is empty
SmallVector<bool> unpackEmpty = {false, false, true};
bool isBatchMatmul = isa<linalg::BatchMatmulOp>(linalgOp);
SmallVector<int64_t> innerPermA = setInnerPermA(isMatmulTransposeA(linalgOp));
SmallVector<int64_t> innerPermB = setInnerPermB(isMatmulTransposeB(linalgOp));
SmallVector<SmallVector<int64_t>> innerPerm = {
innerPermA, innerPermB, {0, 1}};
bool isBatchMatmul = isa<linalg::BatchMatmulOp>(linalgOp);
SmallVector<int64_t> outerPermA =
setOuterPermA(isMatmulTransposeA(linalgOp), isBatchMatmul);
SmallVector<int64_t> outerPermB =
setOuterPermB(isMatmulTransposeB(linalgOp), isBatchMatmul);
SmallVector<SmallVector<int64_t>> outerPerm = {outerPermA, outerPermB};
// Add outer permutation for unpack. NOTE: This currently fails for some
// tests in the AIR pipeline.
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
outerPerm.push_back({0, 2, 1});
} else {
outerPerm.push_back({1, 0});
}

auto packingConfigLevel0Attr = getPackingConfigPackingLevelAttr(
context, packedSizesL0, transposePackIndices, unpackEmpty, innerPerm,
outerPerm);
SmallVector<int64_t> transposePackIndices;
SmallVector<bool> unpackEmpty;
SmallVector<SmallVector<int64_t>> innerPerm;
SmallVector<SmallVector<int64_t>> outerPerm;
SmallVector<PackingConfigPackingLevelAttr> packingConfigLevelsVal;

// Pack level => 1.
// For 2D matmul-like ops, the first level is to pack operands from 2D to 4D.
// If the input is a 4D matmul-like op, this level of packing is not needed.
bool is2DMatmulLike = is2DMatmulLikeOp(linalgOp) || isBatchMatmul;
if (is2DMatmulLike) {
SmallVector<int64_t> packedSizesL0(numLoops, 0);
packedSizesL0[mDims.back()] = packPeelTiling.m0Pack;
packedSizesL0[nDims.back()] = packPeelTiling.n0Pack;
packedSizesL0[kDims.back()] = packPeelTiling.k0Pack;

transposePackIndices = {0, 1, 2};
// There is no corresponding unpack for the specified pack operation
// 0 is used when unpack is empty
unpackEmpty = {false, false, true};
innerPerm = {innerPermA, innerPermB, {0, 1}};
outerPerm = {outerPermA, outerPermB};
// Add outer permutation for unpack. NOTE: This currently fails for some
// tests in the AIR pipeline.
if (isBatchMatmul) {
outerPerm.push_back({0, 2, 1});
} else {
outerPerm.push_back({1, 0});
}

auto packingConfigLevel0Attr = getPackingConfigPackingLevelAttr(
context, packedSizesL0, transposePackIndices, unpackEmpty, innerPerm,
outerPerm);
packingConfigLevelsVal.push_back(packingConfigLevel0Attr);
}

// Pack level => 2.
// The number of loops has increased by 3 due to the first level pack.
SmallVector<int64_t> packedSizesL1(numLoops + 3, 0);
packedSizesL1[mDims.back() + 3] = packPeelTiling.m1Pack;
packedSizesL1[nDims.back() + 3] = packPeelTiling.n1Pack;
packedSizesL1[kDims.back() + 3] = packPeelTiling.k1Pack;
// If the first level pack exists (for 2D matmul-like ops), the number of
// packed dimensions increases by 3; otherwise keep the original number
// of loops.
unsigned numPackedDims = is2DMatmulLike ? numLoops + 3 : numLoops;
unsigned mIdx = is2DMatmulLike ? mDims.back() + 3 : mDims.back();
unsigned nIdx = is2DMatmulLike ? nDims.back() + 3 : nDims.back();
unsigned kIdx = is2DMatmulLike ? kDims.back() + 3 : kDims.back();
SmallVector<int64_t> packedSizesL1(numPackedDims, 0);
packedSizesL1[mIdx] = packPeelTiling.m1Pack;
packedSizesL1[nIdx] = packPeelTiling.n1Pack;
packedSizesL1[kIdx] = packPeelTiling.k1Pack;

// Transpose A matrix from [M K m k m0 k0] to [M K k m m0 k0]
// Transpose C matrix from [M N m n m0 n0] to [M N n m m0 n0]
@@ -519,17 +534,16 @@
// Only the third pack operation has a corresponding unpack operation
unpackEmpty = {false, false, true};
innerPerm = {innerPermA, innerPermB, {0, 1}};
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
if (isBatchMatmul) {
outerPerm = {{0, 1, 2, 4, 3}, {0, 1, 2, 4, 3}, {0, 1, 2, 4, 3}};
} else {
outerPerm = {{0, 1, 3, 2}, {0, 1, 3, 2}, {0, 1, 3, 2}};
}
auto packingConfigLevel1Attr = getPackingConfigPackingLevelAttr(
context, packedSizesL1, transposePackIndices, unpackEmpty, innerPerm,
outerPerm);
packingConfigLevelsVal.push_back(packingConfigLevel1Attr);

SmallVector<PackingConfigPackingLevelAttr> packingConfigLevelsVal = {
packingConfigLevel0Attr, packingConfigLevel1Attr};
auto packingConfigLevels =
PackingConfigPackingLevelsAttr::get(context, packingConfigLevelsVal);
auto config = PackingConfigAttr::get(context, packingConfigLevels);
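
The index bookkeeping for the second pack level above boils down to a small shift rule; a sketch in Python with hypothetical names:

```python
def l1_pack_index(dim_idx, is_2d_matmul_like):
    # After the first-level pack of a 2D matmul-like op, the M/N/K loops
    # each gain an inner counterpart, so the original indices shift by 3.
    # 4D matmul-like ops skip the first pack level, so indices are unchanged.
    return dim_idx + 3 if is_2d_matmul_like else dim_idx
```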
@@ -550,14 +564,22 @@
bool fitsInL2 = (l2SizeA + l2SizeB + l2SizeInit) <
(deviceModel.getMemTileSizeInBytes() * numCols);
int64_t scaleL0 = !isBatchMatmul && fitsInL2 ? 2 : 1;
int64_t m0Tile = packPeelTiling.M0 * scaleL0;
int64_t n0Tile = packPeelTiling.N0 * scaleL0;

SmallVector<int64_t> tileSizeLevel0(numLoops, 0);
if (isa<linalg::BatchMatmulOp>(linalgOp)) {
if (isBatchMatmul) {
assert(!batchDims.empty() && "expected batch dims not empty");
tileSizeLevel0[batchDims[0]] = 1;
}
tileSizeLevel0[mDims[0]] = packPeelTiling.M0 * scaleL0;
tileSizeLevel0[nDims[0]] = packPeelTiling.N0 * scaleL0;
// For 4D matmul-like ops, only tile the outer dims.
// outer_tile_size = total_tile_size / inner_dim_size
if (is4DMatmulLikeOp(linalgOp)) {
m0Tile /= maybeInputDimsAndSizes.value().mSizes.back();
n0Tile /= maybeInputDimsAndSizes.value().nSizes.back();
}
tileSizeLevel0[mDims[0]] = m0Tile;
tileSizeLevel0[nDims[0]] = n0Tile;

SmallVector<int64_t> tileSizeLevel1(numLoops, 0);
tileSizeLevel1[mDims[0]] = numRows;
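
The outer-tile computation for 4D inputs divides the configured tile size by the inner dim size; a worked example with hypothetical tile values (the real ones come from `packPeelTiling`):

```python
# Hypothetical level-0 tile sizes for the M and N dims.
M0_tile, N0_tile = 64, 64
m0_inner, n0_inner = 32, 32   # fixed inner dim sizes of the 4d operands
scale_l0 = 2                  # doubled for non-batch matmuls that fit in L2

# For 4D matmul-like ops only the outer dims are tiled:
# outer_tile_size = total_tile_size / inner_dim_size
m0_tile = (M0_tile * scale_l0) // m0_inner   # 4
n0_tile = (N0_tile * scale_l0) // n0_inner   # 4
```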
