From ceba0cf75d578822ac2e3b42904eec6f19d38f5c Mon Sep 17 00:00:00 2001
From: Hugh Delaney
Date: Mon, 7 Mar 2022 17:18:15 +0000
Subject: [PATCH 1/6] Adding tests for fma relu with half, bf16 and bf16x2
 datatypes

---
 SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp | 275 ++++++++++++++++++++
 1 file changed, 275 insertions(+)
 create mode 100644 SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp

diff --git a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
new file mode 100644
index 0000000000..54cbbb3a40
--- /dev/null
+++ b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
@@ -0,0 +1,275 @@
+// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
+// RUN: %GPU_RUN_PLACEHOLDER %t.out
+//
+// Only cuda backend implements bf16
+// REQUIRES: cuda
+
+#include <sycl/sycl.hpp>
+
+constexpr int N = 32; // All vector sizes divide this
+
+using namespace sycl;
+
+float make_fp32(uint16_t x) {
+  uint32_t y = x;
+  y = y << 16;
+  float *res = reinterpret_cast<float *>(&y);
+  return *res;
+}
+
+uint16_t make_bf16(float x) {
+  uint32_t *res = reinterpret_cast<uint32_t *>(&x);
+  *res = *res >> 16;
+  return (uint16_t)*res;
+}
+
+bool compare_fma_relu_bf16(uint16_t a, uint16_t b, uint16_t c, uint16_t d) {
+  uint32_t a_tmp = a, b_tmp = b, c_tmp = c, d_tmp = d;
+  a_tmp <<= 16;
+  b_tmp <<= 16;
+  c_tmp <<= 16;
+  d_tmp <<= 16;
+  float *a_ptr = reinterpret_cast<float *>(&a_tmp),
+        *b_ptr = reinterpret_cast<float *>(&b_tmp),
+        *c_ptr = reinterpret_cast<float *>(&c_tmp),
+        *d_ptr = reinterpret_cast<float *>(&d_tmp);
+  float tmp_ret = std::fma(*a_ptr, *b_ptr, *c_ptr);
+  tmp_ret = tmp_ret > 0 ? tmp_ret : 0;
+
+  return fabs(tmp_ret - *d_ptr) <=
+         8 * fabs(*d_ptr) * std::numeric_limits<float>::epsilon();
+}
+
+bool compare_fma_relu_bf16x2(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
+  uint16_t *a_beg = reinterpret_cast<uint16_t *>(&a),
+           *b_beg = reinterpret_cast<uint16_t *>(&b),
+           *c_beg = reinterpret_cast<uint16_t *>(&c),
+           *d_beg = reinterpret_cast<uint16_t *>(&d);
+  return compare_fma_relu_bf16(*a_beg, *b_beg, *c_beg, *d_beg) &&
+         compare_fma_relu_bf16(*(a_beg + 1), *(b_beg + 1), *(c_beg + 1),
+                               *(d_beg + 1));
+}
+
+#define TEST_BUILTIN_HALF_SCAL_IMPL(NAME) \
+  { \
+    buffer<half> a_buf(&a[0], N); \
+    buffer<half> b_buf(&b[0], N); \
+    buffer<half> c_buf(&c[0], N); \
+    buffer<half> d_buf(&d[0], N); \
+    q.submit([&](handler &cgh) { \
+      auto A = a_buf.get_access<access::mode::read>(cgh); \
+      auto B = b_buf.get_access<access::mode::read>(cgh); \
+      auto C = c_buf.get_access<access::mode::read>(cgh); \
+      auto D = d_buf.get_access<access::mode::write>(cgh); \
+      cgh.parallel_for(N, [=](id<1> index) { \
+        D[index] = NAME(A[index], B[index], C[index]); \
+      }); \
+    }); \
+  } \
+  for (int i = 0; i < N; i++) { \
+    assert(fabs(d[i] - NAME(a[i], b[i], c[i])) < \
+           std::numeric_limits<half>::epsilon()); \
+  }
+
+#define TEST_BUILTIN_HALF_VEC_IMPL(NAME, SZ) \
+  { \
+    buffer<half##SZ> a_buf((half##SZ *)&a[0], N / SZ); \
+    buffer<half##SZ> b_buf((half##SZ *)&b[0], N / SZ); \
+    buffer<half##SZ> c_buf((half##SZ *)&c[0], N / SZ); \
+    buffer<half##SZ> d_buf((half##SZ *)&d[0], N / SZ); \
+    q.submit([&](handler &cgh) { \
+      auto A = a_buf.get_access<access::mode::read>(cgh); \
+      auto B = b_buf.get_access<access::mode::read>(cgh); \
+      auto C = c_buf.get_access<access::mode::read>(cgh); \
+      auto D = d_buf.get_access<access::mode::write>(cgh); \
+      cgh.parallel_for(N / SZ, [=](id<1> index) { \
+        D[index] = NAME(A[index], B[index], C[index]); \
+      }); \
+    }); \
+  } \
+  for (int i = 0; i < N; i++) { \
+    assert(fabs(d[i] - NAME(a[i], b[i], c[i])) < \
+           std::numeric_limits<half>::epsilon()); \
+  }
+
+#define TEST_BUILTIN_HALF_VEC3_IMPL(NAME) \
+  { \
+    buffer<half3> a_buf((half3 *)&a[0], N / 4); \
+    buffer<half3> b_buf((half3 *)&b[0], N / 4); \
+    buffer<half3> c_buf((half3 *)&c[0], N / 4); \
+    buffer<half3> d_buf((half3 *)&d[0], N / 4); \
+    q.submit([&](handler &cgh) { \
+      auto A = a_buf.get_access<access::mode::read>(cgh); \
+      auto B = b_buf.get_access<access::mode::read>(cgh); \
+      auto C = c_buf.get_access<access::mode::read>(cgh); \
+      auto D = d_buf.get_access<access::mode::write>(cgh); \
+      cgh.parallel_for(N / 4, [=](id<1> index) { \
+        D[index] = NAME(A[index], B[index], C[index]); \
+      }); \
+    }); \
+  } \
+  for (int i = 0; i < N; i++) { \
+    if (i % 4 != 3) { \
+      assert(fabs(d[i] - NAME(a[i], b[i], c[i])) < \
+             std::numeric_limits<half>::epsilon()); \
+    } \
+  }
+
+#define TEST_BUILTIN_BF16_SCAL_IMPL(NAME) \
+  { \
+    buffer<uint16_t> a_buf(&a[0], N); \
+    buffer<uint16_t> b_buf(&b[0], N); \
+    buffer<uint16_t> c_buf(&c[0], N); \
+    buffer<uint16_t> d_buf(&d[0], N); \
+    q.submit([&](handler &cgh) { \
+      auto A = a_buf.get_access<access::mode::read>(cgh); \
+      auto B = b_buf.get_access<access::mode::read>(cgh); \
+      auto C = c_buf.get_access<access::mode::read>(cgh); \
+      auto D = d_buf.get_access<access::mode::write>(cgh); \
+      cgh.parallel_for(N, [=](id<1> index) { \
+        D[index] = NAME(A[index], B[index], C[index]); \
+      }); \
+    }); \
+  } \
+  for (int i = 0; i < N; i++) { \
+    assert(compare_fma_relu_bf16(a[i], b[i], c[i], d[i])); \
+  }
+
+#define TEST_BUILTIN_BF16_VEC_IMPL(NAME, SZ) \
+  { \
+    buffer<ushort##SZ> a_buf((ushort##SZ *)&a[0], N / SZ); \
+    buffer<ushort##SZ> b_buf((ushort##SZ *)&b[0], N / SZ); \
+    buffer<ushort##SZ> c_buf((ushort##SZ *)&c[0], N / SZ); \
+    buffer<ushort##SZ> d_buf((ushort##SZ *)&d[0], N / SZ); \
+    q.submit([&](handler &cgh) { \
+      auto A = a_buf.get_access<access::mode::read>(cgh); \
+      auto B = b_buf.get_access<access::mode::read>(cgh); \
+      auto C = c_buf.get_access<access::mode::read>(cgh); \
+      auto D = d_buf.get_access<access::mode::write>(cgh); \
+      cgh.parallel_for(N / SZ, [=](id<1> index) { \
+        D[index] = NAME(A[index], B[index], C[index]); \
+      }); \
+    }); \
+  } \
+  for (int i = 0; i < N; i++) { \
+    assert(compare_fma_relu_bf16x2(a[i], b[i], c[i], d[i])); \
+  }
+
+#define TEST_BUILTIN_BF16_VEC3_IMPL(NAME) \
+  { \
+    buffer<ushort3> a_buf((ushort3 *)&a[0], N / 4); \
+    buffer<ushort3> b_buf((ushort3 *)&b[0], N / 4); \
+    buffer<ushort3> c_buf((ushort3 *)&c[0], N / 4); \
+    buffer<ushort3> d_buf((ushort3 *)&d[0], N / 4); \
+    q.submit([&](handler &cgh) { \
+      auto A = a_buf.get_access<access::mode::read>(cgh); \
+      auto B = b_buf.get_access<access::mode::read>(cgh); \
+      auto C = c_buf.get_access<access::mode::read>(cgh); \
+      auto D = d_buf.get_access<access::mode::write>(cgh); \
+      cgh.parallel_for(N / 4, [=](id<1> index) { \
+        D[index] = NAME(A[index], B[index], C[index]); \
+      }); \
+    }); \
+  } \
+  for (int i = 0; i < N; i++) { \
+    if (i % 4 != 3) { \
+      assert(compare_fma_relu_bf16x2(a[i], b[i], c[i], d[i])); \
+    } \
+  }
+
+#define TEST_BUILTIN_BF16X2_SCAL_IMPL(NAME) \
+  { \
+    buffer<uint32_t> a_buf((uint32_t *)&a[0], N / 2); \
+    buffer<uint32_t> b_buf((uint32_t *)&b[0], N / 2); \
+    buffer<uint32_t> c_buf((uint32_t *)&c[0], N / 2); \
+    buffer<uint32_t> d_buf((uint32_t *)&d[0], N / 2); \
+    q.submit([&](handler &cgh) { \
+      auto A = a_buf.get_access<access::mode::read>(cgh); \
+      auto B = b_buf.get_access<access::mode::read>(cgh); \
+      auto C = c_buf.get_access<access::mode::read>(cgh); \
+      auto D = d_buf.get_access<access::mode::write>(cgh); \
+      cgh.parallel_for(N / 2, [=](id<1> index) { \
+        D[index] = NAME(A[index], B[index], C[index]); \
+      }); \
+    }); \
+  } \
+  for (int i = 0; i < N; i++) { \
+    assert(compare_fma_relu_bf16x2(a[i], b[i], c[i], d[i])); \
+  }
+
+#define TEST_BUILTIN_BF16X2_VEC_IMPL(NAME, SZ) \
+  { \
+    buffer<uint##SZ> a_buf((uint##SZ *)&a[0], N / (2 * SZ)); \
+    buffer<uint##SZ> b_buf((uint##SZ *)&b[0], N / (2 * SZ)); \
+    buffer<uint##SZ> c_buf((uint##SZ *)&c[0], N / (2 * SZ)); \
+    buffer<uint##SZ> d_buf((uint##SZ *)&d[0], N / (2 * SZ)); \
+    q.submit([&](handler &cgh) { \
+      auto A = a_buf.get_access<access::mode::read>(cgh); \
+      auto B = b_buf.get_access<access::mode::read>(cgh); \
+      auto C = c_buf.get_access<access::mode::read>(cgh); \
+      auto D = d_buf.get_access<access::mode::write>(cgh); \
+      cgh.parallel_for(N / (2 * SZ), [=](id<1> index) { \
+        D[index] = NAME(A[index], B[index], C[index]); \
+      }); \
+    }); \
+  } \
+  for (int i = 0; i < N; i++) { \
+    assert(compare_fma_relu_bf16x2(a[i], b[i], c[i], d[i])); \
+  }
+
+#define TEST_BUILTIN_BF16X2_VEC3_IMPL(NAME) \
+  { \
+    buffer<uint3> a_buf((uint3 *)&a[0], N / (2 * 4)); \
+    buffer<uint3> b_buf((uint3 *)&b[0], N / (2 * 4)); \
+    buffer<uint3> c_buf((uint3 *)&c[0], N / (2 * 4)); \
+    buffer<uint3> d_buf((uint3 *)&d[0], N / (2 * 4)); \
+    q.submit([&](handler &cgh) { \
+      auto A = a_buf.get_access<access::mode::read>(cgh); \
+      auto B = b_buf.get_access<access::mode::read>(cgh); \
+      auto C = c_buf.get_access<access::mode::read>(cgh); \
+      auto D = d_buf.get_access<access::mode::write>(cgh); \
+      cgh.parallel_for(N / (2 * 4), [=](id<1> index) { \
+        D[index] = NAME(A[index], B[index], C[index]); \
+      }); \
+    }); \
+  } \
+  for (int i = 0; i < N; i++) { \
+    if (i % 8 > 5) { \
+      assert(compare_fma_relu_bf16x2(a[i], b[i], c[i], d[i])); \
+    } \
+  }
+
+#define TEST_BUILTIN(NAME, type) \
+  TEST_BUILTIN_##type##_SCAL_IMPL(NAME); \
+  TEST_BUILTIN_##type##_VEC_IMPL(NAME, 2); \
+  TEST_BUILTIN_##type##_VEC3_IMPL(NAME); \
+  TEST_BUILTIN_##type##_VEC_IMPL(NAME, 4); \
+  TEST_BUILTIN_##type##_VEC_IMPL(NAME, 8); \
+  TEST_BUILTIN_##type##_VEC_IMPL(NAME, 16);
+
+int main() {
+  queue q;
+
+  // HALF tests
+  {
+    std::vector<half> a(N), b(N), c(N), d(N);
+    for (int i = 0; i < N; i++) {
+      a[i] = i / (half)N;
+      b[i] = (N - i) / (half)N;
+      c[i] = -i / 2 * (half)N;
+    }
+    TEST_BUILTIN(sycl::ext::oneapi::experimental::fma_relu, HALF);
+  }
+
+  // BF16, BF16X2 tests
+  {
+    std::vector<uint16_t> a(N), b(N), c(N), d(N);
+    for (int i = 0; i < N; i++) {
+      a[i] = make_bf16(i / (float)N);
+      b[i] = make_bf16((N - i) / (float)N);
+      c[i] = make_bf16(-i / 10 * (float)N);
+    }
+    TEST_BUILTIN(sycl::ext::oneapi::experimental::fma_relu, BF16);
+    TEST_BUILTIN(sycl::ext::oneapi::experimental::fma_relu, BF16X2);
+  }
+}

From f678f9e8a929c9de943a418cf1457663bb59fafd Mon Sep 17 00:00:00 2001
From: Hugh Delaney
Date: Tue, 8 Mar 2022 10:53:31 +0000
Subject: [PATCH 2/6] Fix typo

---
 SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
index 54cbbb3a40..14785ff6ff 100644
--- a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
+++ b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
@@ -152,7 +152,7 @@ bool compare_fma_relu_bf16x2(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
     }); \
   } \
   for (int i = 0; i < N; i++) { \
-    assert(compare_fma_relu_bf16x2(a[i], b[i], c[i], d[i])); \
+    assert(compare_fma_relu_bf16(a[i], b[i], c[i], d[i])); \
   }
 
 #define TEST_BUILTIN_BF16_VEC3_IMPL(NAME) \
@@ -173,7 +173,7 @@ bool compare_fma_relu_bf16x2(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
   } \
   for (int i = 0; i < N; i++) { \
     if (i % 4 != 3) { \
-      assert(compare_fma_relu_bf16x2(a[i], b[i], c[i], d[i])); \
+      assert(compare_fma_relu_bf16(a[i], b[i], c[i], d[i])); \
     } \
   }

From 6ed3085b69d9f13ad34b1e8679f9808bc857e5c1 Mon Sep 17 00:00:00 2001
From: Hugh Delaney
Date: Tue, 8 Mar 2022 15:52:53 +0000
Subject: [PATCH 3/6] Changing initial vals for more nonzero tests

---
 SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
index 14785ff6ff..5b746bc018 100644
--- a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
+++ b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
@@ -256,7 +256,7 @@ int main() {
     for (int i = 0; i < N; i++) {
       a[i] = i / (half)N;
       b[i] = (N - i) / (half)N;
-      c[i] = -i / 2 * (half)N;
+      c[i] = -i / 4 / (half)N;
     }
     TEST_BUILTIN(sycl::ext::oneapi::experimental::fma_relu, HALF);
   }
@@ -267,7 +267,7 @@ int main() {
     for (int i = 0; i < N; i++) {
       a[i] = make_bf16(i / (float)N);
       b[i] = make_bf16((N - i) / (float)N);
-      c[i] = make_bf16(-i / 10 * (float)N);
+      c[i] = make_bf16(-i / 4 / (float)N);
     }
     TEST_BUILTIN(sycl::ext::oneapi::experimental::fma_relu, BF16);
     TEST_BUILTIN(sycl::ext::oneapi::experimental::fma_relu, BF16X2);

From ff45d26425a37572ba725c31ffb0943552387d8b Mon Sep 17 00:00:00 2001
From: Hugh Delaney
Date: Thu, 10 Mar 2022 11:19:30 +0000
Subject: [PATCH 4/6] Fix typo

---
 SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
index 5b746bc018..3434845d06 100644
--- a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
+++ b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
@@ -1,4 +1,4 @@
-// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out
+// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out -Xsycl-target-backend --cuda-gpu-arch-sm_80
 // RUN: %GPU_RUN_PLACEHOLDER %t.out
 //
 // Only cuda backend implements bf16

From e79c722f1500b6e7bfd5fe929c265bb8cf7f8c2a Mon Sep 17 00:00:00 2001
From: Hugh Delaney
Date: Thu, 10 Mar 2022 11:21:15 +0000
Subject: [PATCH 5/6] Fix typo

---
 SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
index 3434845d06..754dff94df 100644
--- a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
+++ b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
@@ -1,4 +1,4 @@
-// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out -Xsycl-target-backend --cuda-gpu-arch-sm_80
+// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out -Xsycl-target-backend --cuda-gpu-arch=sm_80
 // RUN: %GPU_RUN_PLACEHOLDER %t.out
 //
 // Only cuda backend implements bf16

From 5e51dc3edb292eb61dc0487d8001ee531ca2658f Mon Sep 17 00:00:00 2001
From: Hugh Delaney
Date: Mon, 4 Apr 2022 15:51:55 +0100
Subject: [PATCH 6/6] Updating test to take bfloat16 instead of uint16_t
 storage type

---
 SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp | 123 ++++++++------------
 1 file changed, 50 insertions(+), 73 deletions(-)

diff --git a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
index 754dff94df..8f2c95fc6d 100644
--- a/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
+++ b/SYCL/BFloat16/fma_relu_half_bf16_bf16x2.cpp
@@ -1,4 +1,4 @@
-// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out -Xsycl-target-backend --cuda-gpu-arch=sm_80
+// RUN: %clangxx -fsycl -fsycl-targets=%sycl_triple %s -o %t.out -Xsycl-target-backend --cuda-gpu-arch=sm_80 
 // RUN: %GPU_RUN_PLACEHOLDER %t.out
 //
 // Only cuda backend implements bf16
@@ -9,42 +9,46 @@
 constexpr int N = 32; // All vector sizes divide this
 
 using namespace sycl;
+using sycl::ext::oneapi::experimental::bfloat16;
 
-float make_fp32(uint16_t x) {
-  uint32_t y = x;
+float make_fp32(bfloat16 x) {
+  uint32_t y = reinterpret_cast<uint16_t &>(x);
   y = y << 16;
-  float *res = reinterpret_cast<float *>(&y);
-  return *res;
+  float res = reinterpret_cast<float &>(y);
+  return res;
 }
 
-uint16_t make_bf16(float x) {
-  uint32_t *res = reinterpret_cast<uint32_t *>(&x);
-  *res = *res >> 16;
-  return (uint16_t)*res;
+bfloat16 make_bf16(float x) {
+  uint32_t res = reinterpret_cast<uint32_t &>(x);
+  res = res >> 16;
+  return reinterpret_cast<bfloat16 &>(res);
}
 
-bool compare_fma_relu_bf16(uint16_t a, uint16_t b, uint16_t c, uint16_t d) {
-  uint32_t a_tmp = a, b_tmp = b, c_tmp = c, d_tmp = d;
+bool compare_fma_relu_bf16(bfloat16 a, bfloat16 b, bfloat16 c, bfloat16 d) {
+  uint32_t a_tmp = reinterpret_cast<uint16_t &>(a),
+           b_tmp = reinterpret_cast<uint16_t &>(b),
+           c_tmp = reinterpret_cast<uint16_t &>(c),
+           d_tmp = reinterpret_cast<uint16_t &>(d);
   a_tmp <<= 16;
   b_tmp <<= 16;
   c_tmp <<= 16;
   d_tmp <<= 16;
-  float *a_ptr = reinterpret_cast<float *>(&a_tmp),
-        *b_ptr = reinterpret_cast<float *>(&b_tmp),
-        *c_ptr = reinterpret_cast<float *>(&c_tmp),
-        *d_ptr = reinterpret_cast<float *>(&d_tmp);
-  float tmp_ret = std::fma(*a_ptr, *b_ptr, *c_ptr);
-  tmp_ret = tmp_ret > 0 ? tmp_ret : 0;
-
-  return fabs(tmp_ret - *d_ptr) <=
-         8 * fabs(*d_ptr) * std::numeric_limits<float>::epsilon();
+  float a_float = reinterpret_cast<float &>(a_tmp),
+        b_float = reinterpret_cast<float &>(b_tmp),
+        c_float = reinterpret_cast<float &>(c_tmp),
+        d_float = reinterpret_cast<float &>(d_tmp);
+  float d_cmp = std::fma(a_float, b_float, c_float);
+  d_cmp = d_cmp > 0 ? d_cmp : 0;
+
+  return fabs(d_float - d_cmp) <=
+         8 * fabs(d_cmp) * std::numeric_limits<float>::epsilon();
 }
 
 bool compare_fma_relu_bf16x2(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
-  uint16_t *a_beg = reinterpret_cast<uint16_t *>(&a),
-           *b_beg = reinterpret_cast<uint16_t *>(&b),
-           *c_beg = reinterpret_cast<uint16_t *>(&c),
-           *d_beg = reinterpret_cast<uint16_t *>(&d);
+  bfloat16 *a_beg = reinterpret_cast<bfloat16 *>(&a),
+           *b_beg = reinterpret_cast<bfloat16 *>(&b),
+           *c_beg = reinterpret_cast<bfloat16 *>(&c),
+           *d_beg = reinterpret_cast<bfloat16 *>(&d);
   return compare_fma_relu_bf16(*a_beg, *b_beg, *c_beg, *d_beg) &&
          compare_fma_relu_bf16(*(a_beg + 1), *(b_beg + 1), *(c_beg + 1),
                                *(d_beg + 1));
@@ -115,12 +119,14 @@ bool compare_fma_relu_bf16x2(uint32_t a, uint32_t b, uint32_t c, uint32_t d) {
     } \
   }
 
+// There are currently no vec types implemented for the bfloat16 class
+// TODO: test vec types once implemented
 #define TEST_BUILTIN_BF16_SCAL_IMPL(NAME) \
   { \
-    buffer<uint16_t> a_buf(&a[0], N); \
-    buffer<uint16_t> b_buf(&b[0], N); \
-    buffer<uint16_t> c_buf(&c[0], N); \
-    buffer<uint16_t> d_buf(&d[0], N); \
+    buffer<bfloat16> a_buf(&a[0], N); \
+    buffer<bfloat16> b_buf(&b[0], N); \
+    buffer<bfloat16> c_buf(&c[0], N); \
+    buffer<bfloat16> d_buf(&d[0], N); \
     q.submit([&](handler &cgh) { \
       auto A = a_buf.get_access<access::mode::read>(cgh); \
       auto B = b_buf.get_access<access::mode::read>(cgh); \
       auto C = c_buf.get_access<access::mode::read>(cgh); \
       auto D = d_buf.get_access<access::mode::write>(cgh); \
       cgh.parallel_for(N, [=](id<1> index) { \
         D[index] = NAME(A[index], B[index], C[index]); \
       }); \
     }); \
   } \
   for (int i = 0; i < N; i++) { \
     assert(compare_fma_relu_bf16(a[i], b[i], c[i], d[i])); \
   }
 
-#define TEST_BUILTIN_BF16_VEC_IMPL(NAME, SZ) \
-  { \
-    buffer<ushort##SZ> a_buf((ushort##SZ *)&a[0], N / SZ); \
-    buffer<ushort##SZ> b_buf((ushort##SZ *)&b[0], N / SZ); \
-    buffer<ushort##SZ> c_buf((ushort##SZ *)&c[0], N / SZ); \
-    buffer<ushort##SZ> d_buf((ushort##SZ *)&d[0], N / SZ); \
-    q.submit([&](handler &cgh) { \
-      auto A = a_buf.get_access<access::mode::read>(cgh); \
-      auto B = b_buf.get_access<access::mode::read>(cgh); \
-      auto C = c_buf.get_access<access::mode::read>(cgh); \
-      auto D = d_buf.get_access<access::mode::write>(cgh); \
-      cgh.parallel_for(N / SZ, [=](id<1> index) { \
-        D[index] = NAME(A[index], B[index], C[index]); \
-      }); \
-    }); \
-  } \
-  for (int i = 0; i < N; i++) { \
-    assert(compare_fma_relu_bf16(a[i], b[i], c[i], d[i])); \
-  }
-
-#define TEST_BUILTIN_BF16_VEC3_IMPL(NAME) \
-  { \
-    buffer<ushort3> a_buf((ushort3 *)&a[0], N / 4); \
-    buffer<ushort3> b_buf((ushort3 *)&b[0], N / 4); \
-    buffer<ushort3> c_buf((ushort3 *)&c[0], N / 4); \
-    buffer<ushort3> d_buf((ushort3 *)&d[0], N / 4); \
-    q.submit([&](handler &cgh) { \
-      auto A = a_buf.get_access<access::mode::read>(cgh); \
-      auto B = b_buf.get_access<access::mode::read>(cgh); \
-      auto C = c_buf.get_access<access::mode::read>(cgh); \
-      auto D = d_buf.get_access<access::mode::write>(cgh); \
-      cgh.parallel_for(N / 4, [=](id<1> index) { \
-        D[index] = NAME(A[index], B[index], C[index]); \
-      }); \
-    }); \
-  } \
-  for (int i = 0; i < N; i++) { \
-    if (i % 4 != 3) { \
-      assert(compare_fma_relu_bf16(a[i], b[i], c[i], d[i])); \
-    } \
-  }
-
 #define TEST_BUILTIN_BF16X2_SCAL_IMPL(NAME) \
   { \
     buffer<uint32_t> a_buf((uint32_t *)&a[0], N / 2); \
     buffer<uint32_t> b_buf((uint32_t *)&b[0], N / 2); \
     buffer<uint32_t> c_buf((uint32_t *)&c[0], N / 2); \
     buffer<uint32_t> d_buf((uint32_t *)&d[0], N / 2); \
     q.submit([&](handler &cgh) { \
       auto A = a_buf.get_access<access::mode::read>(cgh); \
       auto B = b_buf.get_access<access::mode::read>(cgh); \
       auto C = c_buf.get_access<access::mode::read>(cgh); \
       auto D = d_buf.get_access<access::mode::write>(cgh); \
       cgh.parallel_for(N / 2, [=](id<1> index) { \
         D[index] = NAME(A[index], B[index], C[index]); \
       }); \
     }); \
   } \
   for (int i = 0; i < N; i++) { \
     assert(compare_fma_relu_bf16x2(a[i], b[i], c[i], d[i])); \
   }
@@ -261,15 +225,28 @@ int main() {
     TEST_BUILTIN(sycl::ext::oneapi::experimental::fma_relu, HALF);
   }
 
-  // BF16, BF16X2 tests
+  // BF16
   {
-    std::vector<uint16_t> a(N), b(N), c(N), d(N);
+    std::vector<bfloat16> a(N), b(N), c(N), d(N);
     for (int i = 0; i < N; i++) {
       a[i] = make_bf16(i / (float)N);
       b[i] = make_bf16((N - i) / (float)N);
       c[i] = make_bf16(-i / 4 / (float)N);
     }
-    TEST_BUILTIN(sycl::ext::oneapi::experimental::fma_relu, BF16);
+    TEST_BUILTIN_BF16_SCAL_IMPL(fma_relu);
+  }
+
+  // BF16X2
+  {
+    std::vector<uint16_t> a(N), b(N), c(N), d(N);
+    for (int i = 0; i < N; i++) {
+      auto tmp = make_bf16(i / (float)N);
+      a[i] = reinterpret_cast<uint16_t &>(tmp);
+      tmp = make_bf16((N - i) / (float)N);
+      b[i] = reinterpret_cast<uint16_t &>(tmp);
+      tmp = make_bf16(-i / 4 / (float)N);
+      c[i] = reinterpret_cast<uint16_t &>(tmp);
+    }
     TEST_BUILTIN(sycl::ext::oneapi::experimental::fma_relu, BF16X2);
   }
 }
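
Note on the host-side reference computation used throughout this series: every
comparison relies on bf16 being the top 16 bits of an IEEE-754 binary32, so
widening bf16 -> float is a left shift by 16 and narrowing float -> bf16 is a
truncating right shift by 16. The sketch below reproduces that logic as a
minimal standalone plain C++ program (no SYCL required); it is not part of the
patches, the helper names are illustrative, and std::memcpy stands in for the
patches' reinterpret_cast to sidestep strict-aliasing concerns.

    // bf16_fma_relu_reference.cpp - sketch of the tests' host-side checks
    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <cstring>
    #include <limits>

    // Widen bf16 bits to float: bf16 occupies the high half of a binary32.
    float bf16_to_float(uint16_t bits) {
      uint32_t w = static_cast<uint32_t>(bits) << 16;
      float f;
      std::memcpy(&f, &w, sizeof(f));
      return f;
    }

    // Narrow float to bf16 bits by truncation (no round-to-nearest-even).
    uint16_t float_to_bf16(float f) {
      uint32_t w;
      std::memcpy(&w, &f, sizeof(w));
      return static_cast<uint16_t>(w >> 16);
    }

    // Reference fma_relu: fused multiply-add, clamped at zero.
    float fma_relu_ref(float a, float b, float c) {
      float r = std::fma(a, b, c);
      return r > 0.0f ? r : 0.0f;
    }

    int main() {
      // Mirror one element of the tests: compute in float, truncate to
      // bf16, and accept the result within a relative tolerance of
      // 8 * |reference| * FLT_EPSILON, as the patches do.
      uint16_t a = float_to_bf16(0.5f), b = float_to_bf16(0.75f),
               c = float_to_bf16(-0.125f);
      float ref = fma_relu_ref(bf16_to_float(a), bf16_to_float(b),
                               bf16_to_float(c));
      uint16_t d = float_to_bf16(ref);
      assert(std::fabs(bf16_to_float(d) - ref) <=
             8 * std::fabs(ref) * std::numeric_limits<float>::epsilon());
      return 0;
    }

The truncating conversion is also why the tests use a multiple-of-epsilon
tolerance rather than exact equality: the device builtin may round while
make_bf16 truncates, so the two can differ in the last bf16 ulp.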