diff --git a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp index 26d3d3bdc7ec..2600932d797c 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/bfloat16_math.hpp @@ -30,6 +30,14 @@ uint32_t to_uint32_t(sycl::marray x, size_t start) { } } // namespace detail +// According to bfloat16 format, NAN value's exponent field is 0xFF and +// significand has non-zero bits. +template +std::enable_if_t::value, bool> isnan(T x) { + oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); + return (((XBits & 0x7F80) == 0x7F80) && (XBits & 0x7F)) ? true : false; +} + template std::enable_if_t::value, T> fabs(T x) { #if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) @@ -74,20 +82,31 @@ std::enable_if_t::value, T> fmin(T x, T y) { oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); return oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits)); #else - std::ignore = x; - std::ignore = y; - throw runtime_error( - "bfloat16 math functions are not currently supported on the host device.", - PI_ERROR_INVALID_DEVICE); + static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0; + oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); + oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); + if (isnan(x) && isnan(y)) + return oneapi::detail::bitsToBfloat16(CanonicalNan); + + if (isnan(x)) + return y; + if (isnan(y)) + return x; + if (((XBits | YBits) == + static_cast(0x8000)) && + !(XBits & YBits)) + return oneapi::detail::bitsToBfloat16( + static_cast(0x8000)); + + return (x < y) ? x : y; #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } template sycl::marray fmin(sycl::marray x, sycl::marray y) { -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) sycl::marray res; - +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) for (size_t i = 0; i < N / 2; i++) { auto partial_res = __clc_fmin(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2)); @@ -101,15 +120,12 @@ sycl::marray fmin(sycl::marray x, oneapi::detail::bfloat16ToBits(y[N - 1]); res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmin(XBits, YBits)); } - - return res; #else - std::ignore = x; - std::ignore = y; - throw runtime_error( - "bfloat16 math functions are not currently supported on the host device.", - PI_ERROR_INVALID_DEVICE); + for (size_t i = 0; i < N; i++) { + res[i] = fmin(x[i], y[i]); + } #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + return res; } template @@ -119,20 +135,30 @@ std::enable_if_t::value, T> fmax(T x, T y) { oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); return oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits)); #else - std::ignore = x; - std::ignore = y; - throw runtime_error( - "bfloat16 math functions are not currently supported on the host device.", - PI_ERROR_INVALID_DEVICE); + static const oneapi::detail::Bfloat16StorageT CanonicalNan = 0x7FC0; + oneapi::detail::Bfloat16StorageT XBits = oneapi::detail::bfloat16ToBits(x); + oneapi::detail::Bfloat16StorageT YBits = oneapi::detail::bfloat16ToBits(y); + if (isnan(x) && isnan(y)) + return oneapi::detail::bitsToBfloat16(CanonicalNan); + + if (isnan(x)) + return y; + if (isnan(y)) + return x; + if (((XBits | YBits) == + static_cast(0x8000)) && + !(XBits & YBits)) + return oneapi::detail::bitsToBfloat16(0); + + return (x > y) ? x : y; #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) } template sycl::marray fmax(sycl::marray x, sycl::marray y) { -#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) sycl::marray res; - +#if defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) for (size_t i = 0; i < N / 2; i++) { auto partial_res = __clc_fmax(detail::to_uint32_t(x, i * 2), detail::to_uint32_t(y, i * 2)); @@ -146,14 +172,12 @@ sycl::marray fmax(sycl::marray x, oneapi::detail::bfloat16ToBits(y[N - 1]); res[N - 1] = oneapi::detail::bitsToBfloat16(__clc_fmax(XBits, YBits)); } - return res; #else - std::ignore = x; - std::ignore = y; - throw runtime_error( - "bfloat16 math functions are not currently supported on the host device.", - PI_ERROR_INVALID_DEVICE); + for (size_t i = 0; i < N; i++) { + res[i] = fmax(x[i], y[i]); + } #endif // defined(__SYCL_DEVICE_ONLY__) && defined(__NVPTX__) + return res; } template diff --git a/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp b/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp index 5273066c07f5..cbde3b35de21 100644 --- a/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp +++ b/sycl/include/sycl/ext/oneapi/experimental/sycl_complex.hpp @@ -1201,7 +1201,7 @@ SYCL_EXTERNAL complex<_Tp> acos(const complex<_Tp> &__x) { } if (sycl::isinf(__x.imag())) return complex<_Tp>(__pi / _Tp(2), -__x.imag()); - if (__x.real() == 0 && (__x.imag() == 0 || isnan(__x.imag()))) + if (__x.real() == 0 && (__x.imag() == 0 || sycl::isnan(__x.imag()))) return complex<_Tp>(__pi / _Tp(2), -__x.imag()); complex<_Tp> __z = log(__x + sqrt(__sqr(__x) - _Tp(1))); if (sycl::signbit(__x.imag()))