Skip to content

Commit aa2d4f4

Browse files
ville-k authored and benoitsteiner committed
Avoid functions that might not be defined on SYCL device (#51)
* Avoid functions that might not be defined on SYCL device
* Simplify by using Eigen math functions
1 parent ecd6f7a commit aa2d4f4

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

tensorflow/core/lib/random/random_distributions.h

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ limitations under the License.
2727
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
2828
#include "tensorflow/core/lib/random/philox_random.h"
2929

30+
3031
namespace tensorflow {
3132
namespace random {
3233

@@ -373,7 +374,7 @@ class TruncatedNormalDistribution<SingleSampleGenerator, Eigen::half> {
373374
BoxMullerFloat(x0, x1, &f[0], &f[1]);
374375

375376
for (int i = 0; i < 2; ++i) {
376-
if (fabs(f[i]) < kTruncateValue) {
377+
if (Eigen::numext::abs(f[i]) < kTruncateValue) {
377378
results[index++] = Eigen::half(f[i]);
378379
if (index >= kResultElementCount) {
379380
return results;
@@ -416,7 +417,7 @@ class TruncatedNormalDistribution<SingleSampleGenerator, float> {
416417
BoxMullerFloat(x0, x1, &f[0], &f[1]);
417418

418419
for (int i = 0; i < 2; ++i) {
419-
if (fabs(f[i]) < kTruncateValue) {
420+
if (Eigen::numext::abs(f[i]) < kTruncateValue) {
420421
results[index++] = f[i];
421422
if (index >= kResultElementCount) {
422423
return results;
@@ -458,7 +459,7 @@ class TruncatedNormalDistribution<SingleSampleGenerator, double> {
458459
BoxMullerDouble(x0, x1, x2, x3, &d[0], &d[1]);
459460

460461
for (int i = 0; i < 2; ++i) {
461-
if (fabs(d[i]) < kTruncateValue) {
462+
if (Eigen::numext::abs(d[i]) < kTruncateValue) {
462463
results[index++] = d[i];
463464
if (index >= kResultElementCount) {
464465
return results;
@@ -483,12 +484,12 @@ void BoxMullerFloat(uint32 x0, uint32 x1, float* f0, float* f1) {
483484
u1 = epsilon;
484485
}
485486
const float v1 = 2.0f * M_PI * Uint32ToFloat(x1);
486-
const float u2 = sqrt(-2.0f * log(u1));
487-
#if defined(__linux__)
488-
sincosf(v1, f0, f1);
487+
const float u2 = Eigen::numext::sqrt(-2.0f * Eigen::numext::log(u1));
488+
#if defined(TENSORFLOW_USE_SYCL) || !defined(__linux__)
489+
*f0 = Eigen::numext::sin(v1);
490+
*f1 = Eigen::numext::cos(v1);
489491
#else
490-
*f0 = sinf(v1);
491-
*f1 = cosf(v1);
492+
sincosf(v1, f0, f1);
492493
#endif
493494
*f0 *= u2;
494495
*f1 *= u2;
@@ -509,12 +510,12 @@ void BoxMullerDouble(uint32 x0, uint32 x1, uint32 x2, uint32 x3, double* d0,
509510
u1 = epsilon;
510511
}
511512
const double v1 = 2 * M_PI * Uint64ToDouble(x2, x3);
512-
const double u2 = sqrt(-2.0 * log(u1));
513-
#if defined(__linux__)
514-
sincos(v1, d0, d1);
513+
const double u2 = Eigen::numext::sqrt(-2.0 * Eigen::numext::log(u1));
514+
#if defined(TENSORFLOW_USE_SYCL) || !defined(__linux__)
515+
*d0 = Eigen::numext::sin(v1);
516+
*d1 = Eigen::numext::cos(v1);
515517
#else
516-
*d0 = sin(v1);
517-
*d1 = cos(v1);
518+
sincos(v1, d0, d1);
518519
#endif
519520
*d0 *= u2;
520521
*d1 *= u2;

0 commit comments

Comments
 (0)