From f7d6d06986331a1da6982cd9854379922ec97919 Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 12 Dec 2022 16:33:28 +0800 Subject: [PATCH 1/4] [SYCL] Add bfloat16 generic implementation for fmax, fmin Signed:sign-off-by: jinge90 --- .../ext/oneapi/experimental/bfloat16_math.hpp | 78 ++++++++++++------- 1 file changed, 51 insertions(+), 27 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp index 26d3d3bdc7ec3..dee68e38e66fd 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp @@ -30,6 +30,14 @@ uint32_t to_uint32_t(sycl::marray x, size_t start) { } } // namespace detail +// According to bfloat16 format, NAN value's exponent field is 0xFF and +// significand has non-zero bits. +template +std::enable_if_t::value, T> isnan(T x) { + oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); + return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false; +} + template std::enable_if_t::value, T> fabs(T x) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) @@ -74,20 +82,31 @@ std::enable_if_t::value, T> fmin(T x, T y) { oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits)); #else - std::ignore = x; - std::ignore = y; - throw runtime_error( - "bfloat16 math functions are not currently supported on the host device.", - PI_ERROR_INVALID_DEVICE); + static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0; + oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); + oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); + if (isnan(x) && isnan(y)) + return oneapi::detail::bitsToBfloat16(CanonicalNan); + + if (isnan(x)) + return y; + else if (isnan(y)) + return x; + else if (((XBits | YBits) == + static_cast(0x8000)) && + !(XBits & YBits)) + return oneapi::detail::bitsToBfloat16( + static_cast(0x8000)); + else + return (x < y) ? x : y; #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } template sycl::marray fmin(sycl::marray x, sycl::marray y) { -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) sycl::marray res; - +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) for (size_t i = 0; i < N / 2; i++) { auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2)); @@ -101,15 +120,12 @@ sycl::marray fmin(sycl::marray x, oneapi::detail::bfloat16ToBits(y[N - 1]); res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits)); } - - return res; #else - std::ignore = x; - std::ignore = y; - throw runtime_error( - "bfloat16 math functions are not currently supported on the host device.", - PI_ERROR_INVALID_DEVICE); + for (size_t i = 0; i < N; i++) { + res[i] = fmin(x[i], y[i]); + } #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + return res; } template @@ -119,20 +135,30 @@ std::enable_if_t::value, T> fmax(T x, T y) { oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits)); #else - std::ignore = x; - std::ignore = y; - throw runtime_error( - "bfloat16 math functions are not currently supported on the host device.", - PI_ERROR_INVALID_DEVICE); + static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0; + oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); + oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); + if (isnan(x) && isnan(y)) + return oneapi::detail::bitsToBfloat16(CanonicalNan); + + if (isnan(x)) + return y; + else if (isnan(y)) + return x; + else if (((XBits | YBits) == + static_cast(0x8000)) && + !(XBits & YBits)) + return oneapi::detail::bitsToBfloat16(0); + else + return (x > y) ? x : y; #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } template sycl::marray fmax(sycl::marray x, sycl::marray y) { -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) sycl::marray res; - +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) for (size_t i = 0; i < N / 2; i++) { auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2)); @@ -146,14 +172,12 @@ sycl::marray fmax(sycl::marray x, oneapi::detail::bfloat16ToBits(y[N - 1]); res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits)); } - return res; #else - std::ignore = x; - std::ignore = y; - throw runtime_error( - "bfloat16 math functions are not currently supported on the host device.", - PI_ERROR_INVALID_DEVICE); + for (size_t i = 0; i < N; i++) { + res[i] = fmax(x[i], y[i]); + } #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + return res; } template From 293fb427cb652c1e77d5ad78408ce1870fd2f44d Mon Sep 17 00:00:00 2001 From: jinge90 Date: Mon, 12 Dec 2022 23:14:25 +0800 Subject: [PATCH 2/4] 1.Fix return type of isnan(bfloat16) 2. Use sycl::isnan in sycl::ext::oneapi::experimental::complex explicitly to avoid conflicting with bfloat16 isnan Signed-off-by: jinge90 --- sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp | 2 +- sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp index dee68e38e66fd..af75b42cfce08 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp @@ -33,7 +33,7 @@ uint32_t to_uint32_t(sycl::marray x, size_t start) { // According to bfloat16 format, NAN value's exponent field is 0xFF and // significand has non-zero bits. template -std::enable_if_t::value, T> isnan(T x) { +std::enable_if_t::value, bool> isnan(T x) { oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false; } diff --git a/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp b/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp index 5273066c07f51..cbde3b35de21f 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp @@ -1201,7 +1201,7 @@ SYCL_EXTERNAL complex<_Tp> acos(const complex<_Tp> &__x) { } if (sycl::isinf(__x.imag())) return complex<_Tp>(__pi / _Tp(2), -__x.imag()); - if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag()))) + if (__x.real() == 0 && (__x.imag() == 0 || sycl::isnan(__x.imag()))) return complex<_Tp>(__pi / _Tp(2), -__x.imag()); complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); if (sycl::signbit(__x.imag())) From ea8be62c2e77ae5bddcac449c90e6df2dcc8cb10 Mon Sep 17 00:00:00 2001 From: Alexey Bader Date: Thu, 15 Dec 2022 08:37:35 -0800 Subject: [PATCH 3/4] Apply suggestions from code review --- .../ext/oneapi/experimental/bfloat16_math.hpp | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp index af75b42cfce08..ffc959dcaaeb6 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp @@ -90,15 +90,15 @@ std::enable_if_t::value, T> fmin(T x, T y) { if (isnan(x)) return y; - else if (isnan(y)) + if (isnan(y)) return x; - else if (((XBits | YBits) == + if (((XBits | YBits) == static_cast(0x8000)) && !(XBits & YBits)) return oneapi::detail::bitsToBfloat16( static_cast(0x8000)); - else - return (x < y) ? x : y; + + return (x < y) ? x : y; #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } @@ -143,14 +143,14 @@ std::enable_if_t::value, T> fmax(T x, T y) { if (isnan(x)) return y; - else if (isnan(y)) + if (isnan(y)) return x; - else if (((XBits | YBits) == + if (((XBits | YBits) == static_cast(0x8000)) && !(XBits & YBits)) return oneapi::detail::bitsToBfloat16(0); - else - return (x > y) ? x : y; + + return (x > y) ? x : y; #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } From 1bc646ba20c520bd080b970e530fe37178699bee Mon Sep 17 00:00:00 2001 From: Alexey Bader Date: Thu, 15 Dec 2022 08:42:28 -0800 Subject: [PATCH 4/4] Fix formatting. --- .../sycl/ext/oneapi/experimental/bfloat16_math.hpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp index ffc959dcaaeb6..2600932d797c6 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp @@ -93,11 +93,11 @@ std::enable_if_t::value, T> fmin(T x, T y) { if (isnan(y)) return x; if (((XBits | YBits) == - static_cast(0x8000)) && - !(XBits & YBits)) + static_cast(0x8000)) && + !(XBits & YBits)) return oneapi::detail::bitsToBfloat16( static_cast(0x8000)); - + return (x < y) ? x : y; #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } @@ -146,8 +146,8 @@ std::enable_if_t::value, T> fmax(T x, T y) { if (isnan(y)) return x; if (((XBits | YBits) == - static_cast(0x8000)) && - !(XBits & YBits)) + static_cast(0x8000)) && + !(XBits & YBits)) return oneapi::detail::bitsToBfloat16(0); return (x > y) ? x : y;