[ObjectFifo][Ukernel] Update ukernel signature in mm.cc and lower-to-ukernel pass (nod-ai#624)

-- This commit updates the ukernel signature in mm.cc and the
corresponding func.call signature generated via the
`--iree-amdaie-lower-to-ukernels` pass.
-- For the ukernel to work faster, we use template arguments for M x N x K.
The current ukernel expects LHS/RHS/OUT to always be `64x64`. This
commit adds another variant to deal with `32x32`, since that is what we
generate via ObjectFifo.
-- It also adds an e2e CI test for the ObjectFifo ukernel path.

Signed-off-by: Abhishek Varma <[email protected]>
Abhishek-Varma authored Aug 1, 2024
1 parent 7f78462 commit cd91386
Showing 4 changed files with 205 additions and 121 deletions.
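
In effect, the ukernel entry points go from encoding only the element types
(with the tile sizes hard-wired to 64x64x64 inside mm.cc) to also encoding the
M x N x K tiling and the r x s x t intrinsic shape. A before/after sketch of
the generated C declarations for the bf16 x bf16 -> f32 variant, derived from
the macros in the mm.cc diff below:

// Before: one symbol per element-type combo; tile sizes fixed internally.
void matmul_bf16_f32(bfloat16 *a_in, unsigned offsetA, bfloat16 *b_in,
                     unsigned offsetB, float *c_out, unsigned offsetC);
void zero_f32(float *c_out, unsigned offsetC);

// After: LHS/RHS/OUT element types, the MxNxK tiling, and the rxsxt
// intrinsics are all part of the symbol, so the 32x32 ObjectFifo variant
// can coexist with the 64x64 one.
void matmul_bf16_bf16_f32_32x32x32_4x8x4(bfloat16 *a_in, unsigned offsetA,
                                         bfloat16 *b_in, unsigned offsetB,
                                         float *c_out, unsigned offsetC);
void zero_f32_32x32(float *c_out, unsigned offsetC);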
13 changes: 13 additions & 0 deletions build_tools/ci/run_matmul_test.sh
@@ -793,6 +793,10 @@ bf16_i8_shapes_medium=(
  '1536x2048x1536'
)

+bf16_ukernel_shapes_small=(
+  '256x256x256'
+)
+
run_matmul_test_on_shapes ${bf16_i8_shapes_small[@]} \
  --name_prefix "small" \
  --lower_to_aie_pipeline "objectFifo" \
@@ -826,6 +830,15 @@ run_matmul_test_on_shapes ${bf16_i8_shapes_medium[@]} \
  --acc_type "i32" \
  --num_repeat_runs "2"

+run_matmul_test_on_shapes ${bf16_ukernel_shapes_small[@]} \
+  --name_prefix "small" \
+  --lower_to_aie_pipeline "objectFifo" \
+  --tile_pipeline "pack-peel" \
+  --lhs_rhs_type "bf16" \
+  --acc_type "f32" \
+  --num_repeat_runs "2" \
+  --use_ukernel "1"
+
###################################################################
# Chess tests
###################################################################
44 changes: 26 additions & 18 deletions compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm.cc
@@ -263,7 +263,7 @@ void matmul_vectorized(const T_in *__restrict pA, unsigned offsetA,
}

template <unsigned m, unsigned k, unsigned n>
-void matmul_vectorized_4x8x4_bf16_bf16(const bfloat16 *__restrict pA,
+void matmul_vectorized_4x8x4_bf16_bf16_bf16(const bfloat16 *__restrict pA,
                                       unsigned offsetA,
                                       const bfloat16 *__restrict pB,
                                       unsigned offsetB,
@@ -280,7 +280,7 @@ void matmul_vectorized_4x8x4_bf16_bf16(const bfloat16 *__restrict pA,
}

template <unsigned m, unsigned k, unsigned n>
-void matmul_vectorized_4x8x4_bf16_f32(const bfloat16 *__restrict pA,
+void matmul_vectorized_4x8x4_bf16_bf16_f32(const bfloat16 *__restrict pA,
                                      unsigned offsetA,
                                      const bfloat16 *__restrict pB,
                                      unsigned offsetB, float *__restrict pC,
@@ -297,26 +297,34 @@ void matmul_vectorized_4x8x4_bf16_f32(const bfloat16 *__restrict pA,
extern "C" {

-#define combos(X)                            \
-  X(bfloat16, bf16, bfloat16, bf16, 4, 8, 4) \
-  X(bfloat16, bf16, float, f32, 4, 8, 4)
-#define matmul_vectorized_c_func(ctype_in, mlir_type_in, ctype_out,       \
-                                 mlir_type_out, r, s, t)                  \
-  void matmul_##mlir_type_in##_##mlir_type_out(                           \
-      ctype_in *a_in, unsigned offsetA, ctype_in *b_in, unsigned offsetB, \
-      ctype_out *c_out, unsigned offsetC) {                               \
-    matmul_vectorized_##r##x##s##x##t##_##mlir_type_in##_##mlir_type_out< \
-        64, 64, 64>(a_in, offsetA, b_in, offsetB, c_out, offsetC);        \
+#define matmul_combos(X, M, N, K)                                     \
+  X(bfloat16, bf16, bfloat16, bf16, bfloat16, bf16, M, N, K, 4, 8, 4) \
+  X(bfloat16, bf16, bfloat16, bf16, float, f32, M, N, K, 4, 8, 4)
+
+#define zero_fill_combos(X, M, N) \
+  X(bfloat16, bf16, M, N, N/2)    \
+  X(float, f32, M, N, N/2)
+
+#define matmul_vectorized_c_func(lhs_ctype_in, lhs_mlir_type_in,                                                 \
+                                 rhs_ctype_in, rhs_mlir_type_in,                                                 \
+                                 acc_ctype_out, acc_mlir_type_out, M, K, N, r, s, t)                             \
+  void matmul_##lhs_mlir_type_in##_##rhs_mlir_type_in##_##acc_mlir_type_out##_##M##x##K##x##N##_##r##x##s##x##t( \
+      lhs_ctype_in *a_in, unsigned offsetA, rhs_ctype_in *b_in, unsigned offsetB,                                \
+      acc_ctype_out *c_out, unsigned offsetC) {                                                                  \
+    matmul_vectorized_##r##x##s##x##t##_##lhs_mlir_type_in##_##rhs_mlir_type_in##_##acc_mlir_type_out<           \
+        M, K, N>(a_in, offsetA, b_in, offsetB, c_out, offsetC);                                                  \
  }

-#define zero_vectorized_c_func(ctype_in, mlir_type_in, ctype_out, \
-                               mlir_type_out, r, s, t)            \
-  void zero_##mlir_type_out(ctype_out *c_out, unsigned offsetC) { \
-    zero_vectorized<ctype_out, 64, 64, 32>(c_out, offsetC);       \
+#define zero_vectorized_c_func(ctype_out, mlir_type_out, M, N, r)             \
+  void zero_##mlir_type_out##_##M##x##N(ctype_out *c_out, unsigned offsetC) { \
+    zero_vectorized<ctype_out, M, N, r>(c_out, offsetC);                      \
  }

-combos(matmul_vectorized_c_func) combos(zero_vectorized_c_func)
+matmul_combos(matmul_vectorized_c_func, 32, 32, 32)
+matmul_combos(matmul_vectorized_c_func, 64, 64, 64)
+
+zero_fill_combos(zero_vectorized_c_func, 32, 32)
+zero_fill_combos(zero_vectorized_c_func, 64, 64)

} // extern "C"
)chess"
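
For reference, a sketch of what one row of each new combo macro expands to
after preprocessing (whitespace adjusted; the 32x32x32 instantiation shown):

// matmul_combos(matmul_vectorized_c_func, 32, 32, 32), bf16 x bf16 -> f32 row:
void matmul_bf16_bf16_f32_32x32x32_4x8x4(bfloat16 *a_in, unsigned offsetA,
                                         bfloat16 *b_in, unsigned offsetB,
                                         float *c_out, unsigned offsetC) {
  matmul_vectorized_4x8x4_bf16_bf16_f32<32, 32, 32>(a_in, offsetA, b_in,
                                                    offsetB, c_out, offsetC);
}

// zero_fill_combos(zero_vectorized_c_func, 32, 32), f32 row; note r = N/2:
void zero_f32_32x32(float *c_out, unsigned offsetC) {
  zero_vectorized<float, 32, 32, 32 / 2>(c_out, offsetC);
}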
@@ -80,6 +80,31 @@ static FnNameAndDefAttrs getFnNameAndDefAttrs(RewriterBase &rewriter,
/// ================== SAME UTILITIES AS IREE LLVMCPU ====================
/// ======================================================================

+/// Utility to fetch the element type as string.
+static std::string typeToString(Type type) {
+  std::string typeStr;
+  llvm::raw_string_ostream rso(typeStr);
+  type.print(rso);
+  return typeStr;
+}
+
+/// We need to fetch the tiling at M, N and K for the input tensors along with
+/// the intrinsics that the ukernel supports. The following utility helps us
+/// fetch the same.
+static std::tuple<int, int, int, int> getTilingInfo(ShapedType shapedType) {
+  SmallVector<int64_t> shapeVec(shapedType.getShape());
+  int index = 0;
+  if (shapeVec.size() == 6) {
+    index = 2;
+  } else {
+    assert(shapeVec.size() == 4 &&
+           "lhs/rhs/out shape should have rank either 4 or 6");
+  }
+  int M = shapeVec[index + 1] * shapeVec[index + 2];
+  int N = shapeVec[index] * shapeVec[index + 3];
+  return {M, N, shapeVec[index + 2], shapeVec[index + 3]};
+}
+
/// Matches a linalg.generic operation which is basically a tiled matmul and
/// converts it into a iree_codegen.ukernel."iree_amdaie_uk_matmul" operation,
/// that is later lowered into a call to the microkernel.
@@ -93,33 +118,29 @@ static FailureOr<IREE::Codegen::UKernelOpInterface> matchDAGForUKernel(
  Value lhs = op.getDpsInputOperand(0)->get();
  Value rhs = op.getDpsInputOperand(1)->get();
  Value out = op.getDpsInitOperand(0)->get();
+  auto lhsType = llvm::cast<ShapedType>(lhs.getType());
+  auto rhsType = llvm::cast<ShapedType>(rhs.getType());
  auto outType = llvm::cast<ShapedType>(out.getType());
-  Type lhsElemType = llvm::cast<ShapedType>(lhs.getType()).getElementType();
-  Type rhsElemType = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+  Type lhsElemType = lhsType.getElementType();
+  Type rhsElemType = rhsType.getElementType();
  Type outElemType = outType.getElementType();

-  std::string inputOutputElemType = "";
-  if (lhsElemType.isSignlessInteger(32) && rhsElemType.isSignlessInteger(32) &&
-      outElemType.isSignlessInteger(32)) {
-    inputOutputElemType = "i32_i32";
-  } else if (lhsElemType.isBF16() && rhsElemType.isBF16() &&
-             outElemType.isBF16()) {
-    inputOutputElemType = "bf16_bf16";
-  } else if (lhsElemType.isBF16() && rhsElemType.isBF16() &&
-             outElemType.isF32()) {
-    inputOutputElemType = "bf16_f32";
-  } else if (lhsElemType.isF32() && rhsElemType.isF32() &&
-             outElemType.isF32()) {
-    inputOutputElemType = "f32_f32";
-  } else {
-    return rewriter.notifyMatchFailure(
-        op, "unsupported combination of element types for microkernel");
-  }
+  // Tiling for M x K x N as well as the corresponding inner tiling intrinsics
+  // r x s x t.
+  int M, N, K, r, s, t;
+  std::tie(M, K, r, s) = getTilingInfo(lhsType);
+  std::tie(std::ignore, N, std::ignore, t) = getTilingInfo(outType);
+  std::string inputOutputElemTypeAndSize =
+      typeToString(lhsElemType) + "_" + typeToString(rhsElemType) + "_" +
+      typeToString(outElemType) + "_" + std::to_string(M) + "x" +
+      std::to_string(N) + "x" + std::to_string(K) + "_" + std::to_string(r) +
+      "x" + std::to_string(s) + "x" + std::to_string(t);

  Location loc = op.getLoc();

-  auto fn = getFnNameAndDefAttrs(rewriter, ukernelName, inputOutputElemType,
-                                 pathToUkernels, "mm.o");
+  auto fn =
+      getFnNameAndDefAttrs(rewriter, ukernelName, inputOutputElemTypeAndSize,
+                           pathToUkernels, "mm.o");

  // Create UKernel for AMD-AIE.
  auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
@@ -148,19 +169,16 @@ static FailureOr<IREE::Codegen::UKernelOpInterface> matchDAGForUKernel(
  auto outType = llvm::cast<ShapedType>(output.getType());
  Type outElemType = outType.getElementType();

-  std::string elemType = "";
-  if (outElemType.isBF16()) {
-    elemType = "bf16";
-  } else if (outElemType.isF32()) {
-    elemType = "f32";
-  } else {
-    return rewriter.notifyMatchFailure(
-        op, "unsupported combination of element types for microkernel");
-  }
+  // Tiling for M x N as well as the corresponding inner tiling intrinsics r x
+  // t.
+  int M, N, r, t;
+  std::tie(M, N, r, t) = getTilingInfo(outType);
+  std::string elemTypeAndSize = typeToString(outElemType) + "_" +
+                                std::to_string(M) + "x" + std::to_string(N);

  Location loc = op.getLoc();

-  auto fn = getFnNameAndDefAttrs(rewriter, ukernelName, elemType,
+  auto fn = getFnNameAndDefAttrs(rewriter, ukernelName, elemTypeAndSize,
                                 pathToUkernels, "mm.o");

  // Create UKernel for AMD-AIE.
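
To see how getTilingInfo produces the sizes that end up in the generated
function name, here is a small standalone sketch of its logic; the shapes
below are hypothetical, and the real utility takes an MLIR ShapedType rather
than a std::vector:

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <tuple>
#include <vector>

// Mirrors the pass's getTilingInfo: the two outer dims of a rank-4 packed
// shape (or dims 2..5 of a rank-6 shape) count tiles, and the two inner dims
// are the intrinsic sizes.
static std::tuple<int, int, int, int> getTilingInfo(
    const std::vector<int64_t> &shapeVec) {
  int index = 0;
  if (shapeVec.size() == 6) {
    index = 2;
  } else {
    assert(shapeVec.size() == 4 && "rank must be 4 or 6");
  }
  int M = shapeVec[index + 1] * shapeVec[index + 2];
  int N = shapeVec[index] * shapeVec[index + 3];
  return {M, N, static_cast<int>(shapeVec[index + 2]),
          static_cast<int>(shapeVec[index + 3])};
}

int main() {
  // Hypothetical rank-4 packed LHS tile [4, 8, 4, 8]:
  // M = 8 * 4 = 32, K = 4 * 8 = 32, intrinsics r = 4, s = 8.
  auto [M, K, r, s] = getTilingInfo({4, 8, 4, 8});
  std::printf("M=%d K=%d r=%d s=%d\n", M, K, r, s);
  // With the matching output tile, the pass builds the suffix
  // "bf16_bf16_f32_32x32x32_4x8x4", i.e. the 32x32 variant added in mm.cc.
  return 0;
}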