Skip to content

Commit

Permalink
e2e matmul tests: support f64 (iree-org#19093)
Browse files Browse the repository at this point in the history
This adds some e2e matmul tests on CPU for f64 (double precision).
Motivation:
- This is technically supported by IREE.
- This was already > 50% implemented in the e2e matmul testing
framework, and the partly implemented status was not a good steady state
to be in.
- My real motivation is iree-org#19099: I'm doing a round of AMD MFMA /
data tiling improvements, thought we should support a denser set of MFMA
intrinsics, and noticed that this has been supported since CDNA2, so we
shouldn't be prevented from supporting it just because our e2e matmul
tests can't cover it.

Signed-off-by: Benoit Jacob <[email protected]>
  • Loading branch information
bjacob authored Nov 12, 2024
1 parent c29ea90 commit e8f755d
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 16 deletions.
12 changes: 9 additions & 3 deletions tests/e2e/matmul/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,10 @@ PREPROCESSING_PEEL = "--iree-llvmcpu-vector-pproc-strategy=peel"
test_type = "matmul",
) for dtype in [
"f32",
# "f64" (also supported for ArmSME, but not by the test generator)
# f64 disabled because it wasn't supported by the test generator at the time
# this was added. When adding it in the future, consider passing
# --iree-input-demote-f64-to-f32=false to the compiler.
# "f64"
] for transpose_lhs in [
True,
False,
Expand Down Expand Up @@ -135,12 +138,14 @@ X86_64_AVX512_BF16 = X86_64_AVX512 + [
),
compiler_flags = [
"--iree-opt-data-tiling",
] + ["--iree-llvmcpu-enable-ukernels=%s" % ("all" if use_uk else "none")],
] + [
"--iree-llvmcpu-enable-ukernels=%s" % ("all" if use_uk else "none"),
] + (["--iree-input-demote-f64-to-f32=false"] if acc_type == "f64" else []),
generator = ":generate_e2e_matmul_tests",
generator_args = [
"--lhs_rhs_type=%s" % lhs_rhs_type,
"--acc_type=%s" % acc_type,
],
] + (["--shapes=small"] if acc_type == "f64" else []),
tags = ([
# f16/bf16 trigger internal LLVM assertion errors on riscv and wasm.
"noriscv",
Expand Down Expand Up @@ -185,6 +190,7 @@ X86_64_AVX512_BF16 = X86_64_AVX512 + [
[
("i8", "i32"),
("f32", "f32"),
("f64", "f64"),
("f16", "f16"),
("f16", "f32"),
("bf16", "bf16"),
Expand Down
54 changes: 54 additions & 0 deletions tests/e2e/matmul/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,33 @@ iree_generated_e2e_runner_test(
"x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
)

# e2e matmul test with f64 LHS/RHS and an f64 accumulator, compiled for the
# llvm-cpu backend and run with the local-task driver. Data tiling is enabled
# and ukernels are disabled (--iree-llvmcpu-enable-ukernels=none).
iree_generated_e2e_runner_test(
NAME
e2e_matmul_cpu_dt_f64_f64
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
# Only the generator's "small" shape set is requested for f64.
GENERATOR_ARGS
"--lhs_rhs_type=f64"
"--acc_type=f64"
"--shapes=small"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"llvm-cpu"
DRIVERS
"local-task"
# NOTE(review): --iree-input-demote-f64-to-f32=false presumably stops the
# compiler from demoting f64 to f32, so the test really runs in double
# precision -- confirm against the compiler flag documentation.
COMPILER_FLAGS
"--iree-opt-data-tiling"
"--iree-llvmcpu-enable-ukernels=none"
"--iree-input-demote-f64-to-f32=false"
LABELS

# Only the "generic" CPU feature variant is listed for this test.
TARGET_CPU_FEATURES_VARIANTS
"generic"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_cpu_dt_f16_f16
Expand Down Expand Up @@ -346,6 +373,33 @@ iree_generated_e2e_runner_test(
"x86_64:avx512:+avx,+avx2,+fma,+f16c,+avx512f,+avx512vl,+avx512cd,+avx512bw,+avx512dq"
)

# e2e matmul test with f64 LHS/RHS and an f64 accumulator, compiled for the
# llvm-cpu backend and run with the local-task driver. Data tiling is enabled
# and ukernels are enabled (--iree-llvmcpu-enable-ukernels=all).
iree_generated_e2e_runner_test(
NAME
e2e_matmul_cpu_dt_uk_f64_f64
TEST_TYPE
matmul
GENERATOR
"generate_e2e_matmul_tests.py"
# Only the generator's "small" shape set is requested for f64.
GENERATOR_ARGS
"--lhs_rhs_type=f64"
"--acc_type=f64"
"--shapes=small"
TEST_RUNNER
iree_tools_testing_e2e_iree-e2e-matmul-test
TARGET_BACKENDS
"llvm-cpu"
DRIVERS
"local-task"
# NOTE(review): --iree-input-demote-f64-to-f32=false presumably stops the
# compiler from demoting f64 to f32, so the test really runs in double
# precision -- confirm against the compiler flag documentation.
COMPILER_FLAGS
"--iree-opt-data-tiling"
"--iree-llvmcpu-enable-ukernels=all"
"--iree-input-demote-f64-to-f32=false"
LABELS

# Only the "generic" CPU feature variant is listed for this test.
TARGET_CPU_FEATURES_VARIANTS
"generic"
)

iree_generated_e2e_runner_test(
NAME
e2e_matmul_cpu_dt_uk_f16_f16
Expand Down
4 changes: 3 additions & 1 deletion tests/e2e/matmul/generate_e2e_matmul_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class MatrixElemTypeId(enum.Enum):
NONE = ""
I8 = "i8"
I32 = "i32"
F64 = "f64"
F32 = "f32"
F16 = "f16"
BF16 = "bf16"
Expand Down Expand Up @@ -896,6 +897,7 @@ def parse_arguments():
choices=[
"i32",
"i8",
"f64",
"f32",
"f16",
"bf16",
Expand All @@ -910,7 +912,7 @@ def parse_arguments():
parser.add_argument(
"--acc_type",
type=str,
choices=["i32", "f32", "f16", "bf16"],
choices=["i32", "f64", "f32", "f16", "bf16"],
help="Numeric type of the accumulator and result matrices",
required=True,
)
Expand Down
20 changes: 9 additions & 11 deletions tools/testing/e2e/iree-e2e-matmul-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,10 @@
result_data[n + m * n_size] = acc; \
}

// Reference matmul instantiations from macro REFERENCE_MATMUL
// for the f32 input, f32 accumulation, and f32 result.
// [float <= float * float + float]
// Reference matmul instantiations
REFERENCE_MATMUL(float, float, float, float)

// Reference matmul instantiations from macro REFERENCE_MATMUL
// for the int8_t input, int32_t accumulation, and int32_t result.
// [i32 <= i8 * i8 + i32]
REFERENCE_MATMUL(double, double, double, double)
REFERENCE_MATMUL(int8_t, int8_t, int32_t, int32_t)

// Reference matmul instantiations from macro REFERENCE_MATMUL
// for the int32_t input, int32_t accumulation, and int32_t result.
// [i32 <= i32 * i32 + i32]
REFERENCE_MATMUL(int32_t, int32_t, int32_t, int32_t)

// Reference matmul for the f16 input, f16 accumulation, and f16 result.
Expand Down Expand Up @@ -166,6 +157,13 @@ static iree_status_t reference_matmul_element(
m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
(const float*)lhs_data, (const float*)rhs_data, (const float*)acc_data,
(float*)result_data, m, n);
} else if (lhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_64 &&
rhs_type == IREE_HAL_ELEMENT_TYPE_FLOAT_64 &&
acc_type == IREE_HAL_ELEMENT_TYPE_FLOAT_64) {
reference_matmul_double_double_double_double(
m_size, k_size, n_size, lhs_type, rhs_type, acc_type, transpose_rhs,
(const double*)lhs_data, (const double*)rhs_data,
(const double*)acc_data, (double*)result_data, m, n);
} else if (iree_hal_element_type_is_integer(lhs_type, 8) &&
iree_hal_element_type_is_integer(rhs_type, 8) &&
iree_hal_element_type_is_integer(acc_type, 32)) {
Expand Down
15 changes: 14 additions & 1 deletion tools/testing/e2e/test_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,13 @@ iree_test_utils_e2e_value_t iree_test_utils_value_make_f32(float value) {
return result;
}

// Wraps a double-precision value in an e2e test value tagged as F64.
//
// The parameter is `double` (not `float`) so that a value read from an f64
// buffer is stored in `result.f64` unchanged instead of first being narrowed
// to single precision at the call boundary.
// NOTE(review): if this function is also declared in a header, that prototype
// must take `double` as well.
iree_test_utils_e2e_value_t iree_test_utils_value_make_f64(double value) {
  iree_test_utils_e2e_value_t result;
  result.type = IREE_TEST_UTILS_VALUE_TYPE_F64;
  result.f64 = value;
  return result;
}

iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
iree_hal_dim_t index, iree_hal_element_type_t result_type,
const void* data) {
Expand All @@ -167,6 +174,8 @@ iree_test_utils_e2e_value_t iree_test_utils_read_buffer_element(
return iree_test_utils_value_make_bf16(((uint16_t*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_32) {
return iree_test_utils_value_make_f32(((float*)data)[index]);
} else if (result_type == IREE_HAL_ELEMENT_TYPE_FLOAT_64) {
return iree_test_utils_value_make_f64(((double*)data)[index]);
}
iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"unhandled matmul result type"));
Expand Down Expand Up @@ -273,6 +282,10 @@ bool iree_test_utils_result_elements_agree(iree_test_utils_e2e_value_t expected,
if (actual.f32 == expected.f32) return true;
if (iree_test_utils_require_exact_results()) return false;
return fabsf(actual.f32 - expected.f32) < acceptable_fp_delta;
case IREE_TEST_UTILS_VALUE_TYPE_F64:
if (actual.f64 == expected.f64) return true;
if (iree_test_utils_require_exact_results()) return false;
return fabs(actual.f64 - expected.f64) < acceptable_fp_delta;
default:
iree_status_abort(iree_make_status(IREE_STATUS_INVALID_ARGUMENT,
"unhandled value type"));
Expand Down Expand Up @@ -387,7 +400,7 @@ void iree_test_utils_get_min_max_for_element_type(
case IREE_HAL_ELEMENT_TYPE_SINT_64:
case IREE_HAL_ELEMENT_TYPE_FLOAT_64:
*min = -16;
*min = +16;
*max = +16;
break;
case IREE_HAL_ELEMENT_TYPE_UINT_64:
*min = 0;
Expand Down

0 comments on commit e8f755d

Please sign in to comment.