diff --git a/build_tools/ci/run_matmul_test.sh b/build_tools/ci/run_matmul_test.sh
index 2a0b7003a..77f0fba25 100755
--- a/build_tools/ci/run_matmul_test.sh
+++ b/build_tools/ci/run_matmul_test.sh
@@ -793,6 +793,10 @@ bf16_i8_shapes_medium=(
   '1536x2048x1536'
 )
 
+bf16_ukernel_shapes_small=(
+  '256x256x256'
+)
+
 run_matmul_test_on_shapes ${bf16_i8_shapes_small[@]} \
     --name_prefix "small" \
     --lower_to_aie_pipeline "objectFifo" \
@@ -826,6 +830,15 @@ run_matmul_test_on_shapes ${bf16_i8_shapes_medium[@]} \
     --acc_type "i32" \
     --num_repeat_runs "2"
 
+run_matmul_test_on_shapes ${bf16_ukernel_shapes_small[@]} \
+    --name_prefix "small" \
+    --lower_to_aie_pipeline "objectFifo" \
+    --tile_pipeline "pack-peel" \
+    --lhs_rhs_type "bf16" \
+    --acc_type "f32" \
+    --num_repeat_runs "2" \
+    --use_ukernel "1"
+
 ###################################################################
 # Chess tests
 ###################################################################
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm.cc b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm.cc
index f59ae57e8..e8cc072fb 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm.cc
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/mm.cc
@@ -263,7 +263,7 @@ void matmul_vectorized(const T_in *__restrict pA, unsigned offsetA,
 }
 
 template <unsigned m, unsigned k, unsigned n>
-void matmul_vectorized_4x8x4_bf16_bf16(const bfloat16 *__restrict pA,
+void matmul_vectorized_4x8x4_bf16_bf16_bf16(const bfloat16 *__restrict pA,
                                        unsigned offsetA,
                                        const bfloat16 *__restrict pB,
                                        unsigned offsetB,
@@ -280,7 +280,7 @@ void matmul_vectorized_4x8x4_bf16_bf16(const bfloat16 *__restrict pA,
 }
 
 template <unsigned m, unsigned k, unsigned n>
-void matmul_vectorized_4x8x4_bf16_f32(const bfloat16 *__restrict pA,
+void matmul_vectorized_4x8x4_bf16_bf16_f32(const bfloat16 *__restrict pA,
                                       unsigned offsetA,
                                       const bfloat16 *__restrict pB,
                                       unsigned offsetB, float *__restrict pC,
@@ -297,26 +297,34 @@ void matmul_vectorized_4x8x4_bf16_f32(const bfloat16 *__restrict pA,
 
 extern "C" {
 
-#define combos(X)                                                             \
-  X(bfloat16, bf16, bfloat16, bf16, 4, 8, 4)                                  \
-  X(bfloat16, bf16, float, f32, 4, 8, 4)
-
-#define matmul_vectorized_c_func(ctype_in, mlir_type_in, ctype_out,           \
-                                 mlir_type_out, r, s, t)                      \
-  void matmul_##mlir_type_in##_##mlir_type_out(                               \
-      ctype_in *a_in, unsigned offsetA, ctype_in *b_in, unsigned offsetB,     \
-      ctype_out *c_out, unsigned offsetC) {                                   \
-    matmul_vectorized_##r##x##s##x##t##_##mlir_type_in##_##mlir_type_out<     \
-        64, 64, 64>(a_in, offsetA, b_in, offsetB, c_out, offsetC);            \
+#define matmul_combos(X, M, N, K)                                             \
+  X(bfloat16, bf16, bfloat16, bf16, bfloat16, bf16, M, N, K, 4, 8, 4)         \
+  X(bfloat16, bf16, bfloat16, bf16, float, f32, M, N, K, 4, 8, 4)
+
+#define zero_fill_combos(X, M, N)                                             \
+  X(bfloat16, bf16, M, N, N/2)                                                \
+  X(float, f32, M, N, N/2)
+
+#define matmul_vectorized_c_func(lhs_ctype_in, lhs_mlir_type_in,              \
+                                 rhs_ctype_in, rhs_mlir_type_in,              \
+                                 acc_ctype_out, acc_mlir_type_out, M, K, N, r, s, t) \
+  void matmul_##lhs_mlir_type_in##_##rhs_mlir_type_in##_##acc_mlir_type_out##_##M##x##K##x##N##_##r##x##s##x##t( \
+      lhs_ctype_in *a_in, unsigned offsetA, rhs_ctype_in *b_in, unsigned offsetB, \
+      acc_ctype_out *c_out, unsigned offsetC) {                               \
+    matmul_vectorized_##r##x##s##x##t##_##lhs_mlir_type_in##_##rhs_mlir_type_in##_##acc_mlir_type_out< \
+        M, K, N>(a_in, offsetA, b_in, offsetB, c_out, offsetC);               \
   }
 
-#define zero_vectorized_c_func(ctype_in, mlir_type_in, ctype_out,             \
-                               mlir_type_out, r, s, t)                        \
-  void zero_##mlir_type_out(ctype_out *c_out, unsigned offsetC) {             \
-    zero_vectorized(c_out, offsetC);                                          \
+#define zero_vectorized_c_func(ctype_out, mlir_type_out, M, N, r)             \
+  void zero_##mlir_type_out##_##M##x##N(ctype_out *c_out, unsigned offsetC) { \
+    zero_vectorized<ctype_out, M, N, r>(c_out, offsetC);                      \
   }
 
-combos(matmul_vectorized_c_func) combos(zero_vectorized_c_func)
+matmul_combos(matmul_vectorized_c_func, 32, 32, 32)
+matmul_combos(matmul_vectorized_c_func, 64, 64, 64)
+
+zero_fill_combos(zero_vectorized_c_func, 32, 32)
+zero_fill_combos(zero_vectorized_c_func, 64, 64)
 
 }  // extern "C"
 
 )chess"
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToUKernels.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToUKernels.cpp
index bc039a85b..ae2f0c8ee 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToUKernels.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/AMDAIELowerToUKernels.cpp
@@ -80,6 +80,31 @@ static FnNameAndDefAttrs getFnNameAndDefAttrs(RewriterBase &rewriter,
 /// ================== SAME UTILITIES AS IREE LLVMCPU ====================
 /// ======================================================================
 
+/// Utility to fetch the element type as string.
+static std::string typeToString(Type type) {
+  std::string typeStr;
+  llvm::raw_string_ostream rso(typeStr);
+  type.print(rso);
+  return typeStr;
+}
+
+/// We need to fetch the tiling at M, N and K for the input tensors along with
+/// the intrinsics that the ukernel supports. The following utility helps us
+/// fetch the same.
+static std::tuple<int, int, int, int> getTilingInfo(ShapedType shapedType) {
+  SmallVector<int64_t> shapeVec(shapedType.getShape());
+  int index = 0;
+  if (shapeVec.size() == 6) {
+    index = 2;
+  } else {
+    assert(shapeVec.size() == 4 &&
+           "lhs/rhs/out shape should have rank either 4 or 6");
+  }
+  int M = shapeVec[index + 1] * shapeVec[index + 2];
+  int N = shapeVec[index] * shapeVec[index + 3];
+  return {M, N, shapeVec[index + 2], shapeVec[index + 3]};
+}
+
 /// Matches a linalg.generic operation which is basically a tiled matmul and
 /// converts it into a iree_codegen.ukernel."iree_amdaie_uk_matmul" operation,
 /// that is later lowered into a call to the microkernel.
@@ -93,33 +118,29 @@ static FailureOr<IREE::Codegen::UKernelOpInterface> matchDAGForUKernel(
   Value lhs = op.getDpsInputOperand(0)->get();
   Value rhs = op.getDpsInputOperand(1)->get();
   Value out = op.getDpsInitOperand(0)->get();
+  auto lhsType = llvm::cast<ShapedType>(lhs.getType());
+  auto rhsType = llvm::cast<ShapedType>(rhs.getType());
   auto outType = llvm::cast<ShapedType>(out.getType());
-  Type lhsElemType = llvm::cast<ShapedType>(lhs.getType()).getElementType();
-  Type rhsElemType = llvm::cast<ShapedType>(rhs.getType()).getElementType();
+  Type lhsElemType = lhsType.getElementType();
+  Type rhsElemType = rhsType.getElementType();
   Type outElemType = outType.getElementType();
 
-  std::string inputOutputElemType = "";
-  if (lhsElemType.isSignlessInteger(32) && rhsElemType.isSignlessInteger(32) &&
-      outElemType.isSignlessInteger(32)) {
-    inputOutputElemType = "i32_i32";
-  } else if (lhsElemType.isBF16() && rhsElemType.isBF16() &&
-             outElemType.isBF16()) {
-    inputOutputElemType = "bf16_bf16";
-  } else if (lhsElemType.isBF16() && rhsElemType.isBF16() &&
-             outElemType.isF32()) {
-    inputOutputElemType = "bf16_f32";
-  } else if (lhsElemType.isF32() && rhsElemType.isF32() &&
-             outElemType.isF32()) {
-    inputOutputElemType = "f32_f32";
-  } else {
-    return rewriter.notifyMatchFailure(
-        op, "unsupported combination of element types for microkernel");
-  }
+  // Tiling for M x K x N as well as the corresponding inner tiling intrinsics
+  // r x s x t.
+  int M, N, K, r, s, t;
+  std::tie(M, K, r, s) = getTilingInfo(lhsType);
+  std::tie(std::ignore, N, std::ignore, t) = getTilingInfo(outType);
+  std::string inputOutputElemTypeAndSize =
+      typeToString(lhsElemType) + "_" + typeToString(rhsElemType) + "_" +
+      typeToString(outElemType) + "_" + std::to_string(M) + "x" +
+      std::to_string(N) + "x" + std::to_string(K) + "_" + std::to_string(r) +
+      "x" + std::to_string(s) + "x" + std::to_string(t);
 
   Location loc = op.getLoc();
-  auto fn = getFnNameAndDefAttrs(rewriter, ukernelName, inputOutputElemType,
-                                 pathToUkernels, "mm.o");
+  auto fn =
+      getFnNameAndDefAttrs(rewriter, ukernelName, inputOutputElemTypeAndSize,
+                           pathToUkernels, "mm.o");
 
   // Create UKernel for AMD-AIE.
   auto genericMicroKernelOp = rewriter.create<IREE::Codegen::UKernelGenericOp>(
@@ -148,19 +169,16 @@ static FailureOr<IREE::Codegen::UKernelOpInterface> matchDAGForUKernel(
   auto outType = llvm::cast<ShapedType>(output.getType());
   Type outElemType = outType.getElementType();
 
-  std::string elemType = "";
-  if (outElemType.isBF16()) {
-    elemType = "bf16";
-  } else if (outElemType.isF32()) {
-    elemType = "f32";
-  } else {
-    return rewriter.notifyMatchFailure(
-        op, "unsupported combination of element types for microkernel");
-  }
+  // Tiling for M x N as well as the corresponding inner tiling intrinsics r x
+  // t.
+  int M, N, r, t;
+  std::tie(M, N, r, t) = getTilingInfo(outType);
+  std::string elemTypeAndSize = typeToString(outElemType) + "_" +
+                                std::to_string(M) + "x" + std::to_string(N);
 
   Location loc = op.getLoc();
-  auto fn = getFnNameAndDefAttrs(rewriter, ukernelName, elemType,
+  auto fn = getFnNameAndDefAttrs(rewriter, ukernelName, elemTypeAndSize,
                                  pathToUkernels, "mm.o");
 
   // Create UKernel for AMD-AIE.
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lower_to_ukernel.mlir b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lower_to_ukernel.mlir
index c915a0af5..2dcb7b007 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lower_to_ukernel.mlir
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/test/lower_to_ukernel.mlir
@@ -28,8 +28,8 @@ func.func @disabled_ukernel(%arg0 : tensor<?x?x?x?xi32>, %arg1 : tensor<?x?x?x?xi32>,
 
 // -----
 
-func.func @generic_matmul_i32i32i32_pad_pack(%arg0 : tensor<?x?x?x?xi32>, %arg1 : tensor<?x?x?x?xi32>,
-    %arg2 : tensor<?x?x?x?xi32>) -> tensor<?x?x?x?xi32> attributes {
+func.func @generic_matmul_i32i32i32_pad_pack(%arg0 : tensor<8x16x4x8xi32>, %arg1 : tensor<16x8x8x4xi32>,
+    %arg2 : tensor<16x16x4x4xi32>) -> tensor<16x16x4x4xi32> attributes {
   hal.executable.target = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_arch = "chip-tbd", ukernels = "all"}>
 } {
   %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>,
@@ -37,23 +37,23 @@ func.func @generic_matmul_i32i32i32_pad_pack(%arg0 : tensor<?x?x?x?xi32>, %arg1
                                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>,
                                         affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>],
                      iterator_types = ["parallel", "parallel", "reduction",
                                        "parallel", "parallel", "reduction"]
-  } ins(%arg0, %arg1 : tensor<?x?x?x?xi32>, tensor<?x?x?x?xi32>)
-    outs(%arg2 : tensor<?x?x?x?xi32>)
+  } ins(%arg0, %arg1 : tensor<8x16x4x8xi32>, tensor<16x8x8x4xi32>)
+    outs(%arg2 : tensor<16x16x4x4xi32>)
   {
     ^bb0(%in: i32, %in_9: i32, %out: i32):
       %22 = arith.muli %in, %in_9 : i32
       %23 = arith.addi %out, %22 : i32
       linalg.yield %23 : i32
-  } -> tensor<?x?x?x?xi32>
+  } -> tensor<16x16x4x4xi32>
 
-  return %0 : tensor<?x?x?x?xi32>
+  return %0 : tensor<16x16x4x4xi32>
 }
 
 // CHECK: func @generic_matmul_i32i32i32_pad_pack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xi32>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?x?xi32>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?x?xi32>)
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8x16x4x8xi32>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<16x8x8x4xi32>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<16x16x4x4xi32>)
 // CHECK-NOT: linalg.generic
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "matmul_i32_i32"
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "matmul_i32_i32_i32_64x64x64_4x8x4"
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME: outs(%[[ARG2]] :
 // CHECK-SAME: fn_def_attrs {link_with = "/custom/path/to/ukernels/mm.o"}
@@ -67,28 +67,29 @@ func.func @generic_matmul_i32i32i32_pad_pack(%arg0 : tensor<?x?x?x?xi32>, %arg1
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
 module {
-  func.func @generic_matmul_bf16bf16f32_pad_pack(%arg0: tensor<?x?x?x?xbf16>, %arg1: tensor<?x?x?x?xbf16>, %arg2: tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32> attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
+  func.func @generic_matmul_bf16bf16f32_pad_pack(%arg0: tensor<8x16x4x8xbf16>, %arg1: tensor<16x8x8x4xbf16>,
+      %arg2: tensor<16x16x4x4xf32>) -> tensor<16x16x4x4xf32> attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
     %0 = linalg.generic {indexing_maps = [#map, #map1, #map2],
                          iterator_types = ["parallel", "parallel", "reduction",
                                            "parallel", "parallel", "reduction"]
-    } ins(%arg0, %arg1 : tensor<?x?x?x?xbf16>, tensor<?x?x?x?xbf16>)
-      outs(%arg2 : tensor<?x?x?x?xf32>) {
+    } ins(%arg0, %arg1 : tensor<8x16x4x8xbf16>, tensor<16x8x8x4xbf16>)
+      outs(%arg2 : tensor<16x16x4x4xf32>) {
       ^bb0(%in: bf16, %in_0: bf16, %out: f32):
         %1 = arith.extf %in : bf16 to f32
         %2 = arith.extf %in_0 : bf16 to f32
         %3 = arith.mulf %1, %2 : f32
        %4 = arith.addf %out, %3 : f32
        linalg.yield %4 : f32
-    } -> tensor<?x?x?x?xf32>
-    return %0 : tensor<?x?x?x?xf32>
+    } -> tensor<16x16x4x4xf32>
+    return %0 : tensor<16x16x4x4xf32>
   }
 }
 // CHECK: func @generic_matmul_bf16bf16f32_pad_pack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xbf16>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?x?xbf16>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?x?xf32>)
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8x16x4x8xbf16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<16x8x8x4xbf16>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<16x16x4x4xf32>)
 // CHECK-NOT: linalg.generic
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "matmul_bf16_f32"
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "matmul_bf16_bf16_f32_64x64x64_4x8x4"
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME: outs(%[[ARG2]] :
 // CHECK-SAME: fn_def_attrs {link_with = "/custom/path/to/ukernels/mm.o"}
@@ -102,26 +103,32 @@ module {
 #map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
 #map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
 module {
-  func.func @generic_matmul_bf16bf16bf16_pad_pack(%arg0: tensor<?x?x?x?xbf16>, %arg1: tensor<?x?x?x?xbf16>, %arg2: tensor<?x?x?x?xbf16>) -> tensor<?x?x?x?xbf16> attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
-    %0 = linalg.generic {indexing_maps = [#map, #map1, #map2],
-                         iterator_types = ["parallel", "parallel", "reduction",
-                                           "parallel", "parallel", "reduction"]
-    } ins(%arg0, %arg1 : tensor<?x?x?x?xbf16>, tensor<?x?x?x?xbf16>)
-      outs(%arg2 : tensor<?x?x?x?xbf16>) {
-      ^bb0(%in: bf16, %in_0: bf16, %out: bf16):
-        %3 = arith.mulf %in, %in_0 : bf16
-        %4 = arith.addf %out, %3 : bf16
-        linalg.yield %4 : bf16
-    } -> tensor<?x?x?x?xbf16>
-    return %0 : tensor<?x?x?x?xbf16>
+  func.func @generic_matmul_bf16bf16f32_pack_peel(%arg0: tensor<1x1x8x16x4x8xbf16>, %arg1: tensor<1x1x16x8x8x4xbf16>,
+      %arg2: tensor<1x1x16x16x4x4xf32>) -> tensor<1x1x16x16x4x4xf32> attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>,
+                                          affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>,
+                                          affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>],
+                         iterator_types = ["parallel", "parallel", "reduction",
+                                           "parallel", "parallel", "reduction",
+                                           "parallel", "parallel", "reduction"]
+    } ins(%arg0, %arg1 : tensor<1x1x8x16x4x8xbf16>, tensor<1x1x16x8x8x4xbf16>)
+      outs(%arg2 : tensor<1x1x16x16x4x4xf32>) {
+      ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+        %1 = arith.extf %in : bf16 to f32
+        %2 = arith.extf %in_0 : bf16 to f32
+        %3 = arith.mulf %1, %2 : f32
+        %4 = arith.addf %out, %3 : f32
+        linalg.yield %4 : f32
+    } -> tensor<1x1x16x16x4x4xf32>
+    return %0 : tensor<1x1x16x16x4x4xf32>
   }
 }
-// CHECK: func @generic_matmul_bf16bf16bf16_pad_pack(
-// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?x?xbf16>
-// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?x?xbf16>
-// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?x?xbf16>)
+// CHECK: func @generic_matmul_bf16bf16f32_pack_peel(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x1x8x16x4x8xbf16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x1x16x8x8x4xbf16>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<1x1x16x16x4x4xf32>)
 // CHECK-NOT: linalg.generic
-// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "matmul_bf16_bf16"
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "matmul_bf16_bf16_f32_64x64x64_4x8x4"
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME: outs(%[[ARG2]] :
 // CHECK-SAME: fn_def_attrs {link_with = "/custom/path/to/ukernels/mm.o"}
@@ -130,17 +137,56 @@ module {
 
 // -----
 
-func.func @zero_fill(%arg0 : tensor<?x?x?x?xbf16>) -> tensor<?x?x?x?xbf16> attributes {
+#executable_target_amdaie_xclbin_fb = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_arch = "chip-tbd", ukernels = "all"}>
+#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>
+#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>
+#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>
+module {
+  func.func @generic_matmul_bf16bf16f32_pack_peel_objectfifo(%arg0: tensor<1x1x4x8x4x8xbf16>, %arg1: tensor<1x1x8x4x8x4xbf16>,
+      %arg2: tensor<1x1x8x8x4x4xf32>) -> tensor<1x1x8x8x4x4xf32> attributes {hal.executable.target = #executable_target_amdaie_xclbin_fb} {
+    %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d2, d5, d3, d6, d8)>,
+                                          affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d2, d1, d4, d5, d8, d7)>,
+                                          affine_map<(d0, d1, d2, d3, d4, d5, d6, d7, d8) -> (d0, d1, d4, d3, d6, d7)>],
+                         iterator_types = ["parallel", "parallel", "reduction",
+                                           "parallel", "parallel", "reduction",
+                                           "parallel", "parallel", "reduction"]
+    } ins(%arg0, %arg1 : tensor<1x1x4x8x4x8xbf16>, tensor<1x1x8x4x8x4xbf16>)
+      outs(%arg2 : tensor<1x1x8x8x4x4xf32>) {
+      ^bb0(%in: bf16, %in_0: bf16, %out: f32):
+        %1 = arith.extf %in : bf16 to f32
+        %2 = arith.extf %in_0 : bf16 to f32
+        %3 = arith.mulf %1, %2 : f32
+        %4 = arith.addf %out, %3 : f32
+        linalg.yield %4 : f32
+    } -> tensor<1x1x8x8x4x4xf32>
+    return %0 : tensor<1x1x8x8x4x4xf32>
+  }
+}
+// CHECK: func @generic_matmul_bf16bf16f32_pack_peel_objectfifo(
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<1x1x4x8x4x8xbf16>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<1x1x8x4x8x4xbf16>
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<1x1x8x8x4x4xf32>)
+// CHECK-NOT: linalg.generic
+// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "matmul_bf16_bf16_f32_32x32x32_4x8x4"
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
+// CHECK-SAME: outs(%[[ARG2]] :
+// CHECK-SAME: fn_def_attrs {link_with = "/custom/path/to/ukernels/mm.o"}
"/custom/path/to/ukernels/mm.o"} +// CHECK-SAME: strided_outer_dims(0) +// CHECK: return %[[MICRO_KERNEL]] + +// ----- + +func.func @zero_fill(%arg0 : tensor<16x16x4x4xbf16>) -> tensor<16x16x4x4xbf16> attributes { hal.executable.target = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_arch = "chip-tbd", ukernels = "all"}> } { %cst = arith.constant 0.0 : bf16 - %fill = linalg.fill ins(%cst : bf16) outs(%arg0 : tensor) -> tensor - return %fill : tensor + %fill = linalg.fill ins(%cst : bf16) outs(%arg0 : tensor<16x16x4x4xbf16>) -> tensor<16x16x4x4xbf16> + return %fill : tensor<16x16x4x4xbf16> } // CHECK: func @zero_fill( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor) +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<16x16x4x4xbf16>) // CHECK-NOT: linalg.fill -// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "zero_bf16" +// CHECK: %[[MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "zero_bf16_64x64" // CHECK-SAME: outs(%[[ARG0]] : // CHECK-SAME: fn_def_attrs {link_with = "/custom/path/to/ukernels/mm.o"} // CHECK-SAME: strided_outer_dims(0) @@ -148,12 +194,12 @@ func.func @zero_fill(%arg0 : tensor) -> tensor attri // ----- -func.func @non_zero_fill(%arg0 : tensor) -> tensor attributes { +func.func @non_zero_fill(%arg0 : tensor<16x16x4x4xbf16>) -> tensor<16x16x4x4xbf16> attributes { hal.executable.target = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_arch = "chip-tbd", ukernels = "all"}> } { %cst = arith.constant 7.0 : bf16 - %fill = linalg.fill ins(%cst : bf16) outs(%arg0 : tensor) -> tensor - return %fill : tensor + %fill = linalg.fill ins(%cst : bf16) outs(%arg0 : tensor<16x16x4x4xbf16>) -> tensor<16x16x4x4xbf16> + return %fill : tensor<16x16x4x4xbf16> } // CHECK: func @non_zero_fill // CHECK: linalg.fill @@ -161,38 +207,38 @@ func.func @non_zero_fill(%arg0 : tensor) -> tensor a // ----- -func.func @zero_fill_with_matmul(%arg0 : tensor, %arg1 : tensor, - %arg2 : tensor) -> tensor attributes { +func.func @zero_fill_with_matmul(%arg0 : tensor<8x16x4x8xbf16>, %arg1 : tensor<16x8x8x4xbf16>, + %arg2 : tensor<16x16x4x4xbf16>) -> tensor<16x16x4x4xbf16> attributes { hal.executable.target = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_arch = "chip-tbd", ukernels = "all"}> } { %cst = arith.constant 0.0 : bf16 - %fill = linalg.fill ins(%cst : bf16) outs(%arg2 : tensor) -> tensor + %fill = linalg.fill ins(%cst : bf16) outs(%arg2 : tensor<16x16x4x4xbf16>) -> tensor<16x16x4x4xbf16> %matmul = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>, affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>], iterator_types = ["parallel", "parallel", "reduction", "parallel", "parallel", "reduction"] - } ins(%arg0, %arg1 : tensor, tensor) - outs(%fill : tensor) + } ins(%arg0, %arg1 : tensor<8x16x4x8xbf16>, tensor<16x8x8x4xbf16>) + outs(%fill : tensor<16x16x4x4xbf16>) { ^bb0(%in: bf16, %in_9: bf16, %out: bf16): %22 = arith.mulf %in, %in_9 : bf16 %23 = arith.addf %out, %22 : bf16 linalg.yield %23 : bf16 - } -> tensor - return %matmul : tensor + } -> tensor<16x16x4x4xbf16> + return %matmul : tensor<16x16x4x4xbf16> } // CHECK: func @zero_fill_with_matmul( -// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor -// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor) +// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<8x16x4x8xbf16> +// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<16x8x8x4xbf16> +// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: 
 // CHECK-NOT: linalg.fill
-// CHECK: %[[ZERO_FILL_MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "zero_bf16"
+// CHECK: %[[ZERO_FILL_MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "zero_bf16_64x64"
 // CHECK-SAME: outs(%[[ARG2]] :
 // CHECK-SAME: fn_def_attrs {link_with = "/custom/path/to/ukernels/mm.o"}
 // CHECK-SAME: strided_outer_dims(0)
 // CHECK-NOT: linalg.generic
-// CHECK: %[[MATMUL_MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "matmul_bf16_bf16"
+// CHECK: %[[MATMUL_MICRO_KERNEL:.+]] = iree_codegen.ukernel.generic "matmul_bf16_bf16_bf16_64x64x64_4x8x4"
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] :
 // CHECK-SAME: outs(%[[ZERO_FILL_MICRO_KERNEL]] :
 // CHECK-SAME: fn_def_attrs {link_with = "/custom/path/to/ukernels/mm.o"}
@@ -205,48 +251,47 @@ func.func @zero_fill_with_matmul(%arg0 : tensor<?x?x?x?xbf16>, %arg1 : tensor<?x?x?x?xbf16>,
 // Matmul followed by an elementwise operation x -> x + 1. The linalg.matmul should be lowered
 // to a ukernel, as should the linalg.fill, but the final elementwise addition
 // should not be.
-func.func @zero_fill_matmul_elmwise(%arg0 : tensor<?x?x?x?xbf16>, %arg1 : tensor<?x?x?x?xbf16>,
-    %arg2 : tensor<?x?x?x?xbf16>) -> tensor<?x?x?x?xbf16> attributes {
+func.func @zero_fill_matmul_elmwise(%arg0 : tensor<8x16x4x8xbf16>, %arg1 : tensor<16x8x8x4xbf16>,
+    %arg2 : tensor<16x16x4x4xbf16>) -> tensor<16x16x4x4xbf16> attributes {
   hal.executable.target = #hal.executable.target<"amd-aie", "amdaie-xclbin-fb", {target_arch = "chip-tbd", ukernels = "all"}>
 } {
   %cst = arith.constant 0.0 : bf16
   %cst_1 = arith.constant 1.0 : bf16
-  %fill = linalg.fill ins(%cst : bf16) outs(%arg2 : tensor<?x?x?x?xbf16>) -> tensor<?x?x?x?xbf16>
-  %matmul = linalg.generic {indexing_maps = [
-      affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>,
-      affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>,
-      affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>],
+  %fill = linalg.fill ins(%cst : bf16) outs(%arg2 : tensor<16x16x4x4xbf16>) -> tensor<16x16x4x4xbf16>
+  %matmul = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d2, d0, d3, d5)>,
+                                             affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d2, d5, d4)>,
+                                             affine_map<(d0, d1, d2, d3, d4, d5) -> (d1, d0, d3, d4)>],
                      iterator_types = ["parallel", "parallel", "reduction",
                                        "parallel", "parallel", "reduction"]
-  } ins(%arg0, %arg1 : tensor<?x?x?x?xbf16>, tensor<?x?x?x?xbf16>)
-    outs(%fill : tensor<?x?x?x?xbf16>)
+  } ins(%arg0, %arg1 : tensor<8x16x4x8xbf16>, tensor<16x8x8x4xbf16>)
+    outs(%fill : tensor<16x16x4x4xbf16>)
   {
     ^bb0(%in: bf16, %in_9: bf16, %out: bf16):
       %22 = arith.mulf %in, %in_9 : bf16
      %23 = arith.addf %out, %22 : bf16
      linalg.yield %23 : bf16
-  } -> tensor<?x?x?x?xbf16>
+  } -> tensor<16x16x4x4xbf16>
 
   // Perform an elementwise addition of 1 to the result of the matmul.
   %matmul_add = linalg.generic {indexing_maps = [
       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>,
       affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
-      ins(%matmul : tensor<?x?x?x?xbf16>)
-      outs(%arg2 : tensor<?x?x?x?xbf16>)
+      ins(%matmul : tensor<16x16x4x4xbf16>)
+      outs(%arg2 : tensor<16x16x4x4xbf16>)
   {
     ^bb0(%in: bf16, %out: bf16):
       %1 = arith.addf %in, %cst_1 : bf16
       linalg.yield %1 : bf16
-  } -> tensor<?x?x?x?xbf16>
+  } -> tensor<16x16x4x4xbf16>
 
-  return %matmul_add : tensor<?x?x?x?xbf16>
+  return %matmul_add : tensor<16x16x4x4xbf16>
 }
 
 // CHECK: func @zero_fill_matmul_elmwise
 // CHECK-NOT: linalg.fill
-// CHECK: iree_codegen.ukernel.generic "zero_bf16"
+// CHECK: iree_codegen.ukernel.generic "zero_bf16_64x64"
 // CHECK-NOT: linalg.fill
-// CHECK: iree_codegen.ukernel.generic "matmul_bf16_bf16"
+// CHECK: iree_codegen.ukernel.generic "matmul_bf16_bf16_bf16_64x64x64_4x8x4"
 // CHECK: linalg.generic
 // CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel", "parallel"]
 // CHECK: return