@@ -20,18 +20,29 @@ limitations under the License.
#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"
+ #include "tensorflow/core/util/env_var.h"
+ #include "tensorflow/core/util/gpu_device_functions.h"
#include "tensorflow/core/util/gpu_kernel_helper.h"

namespace tensorflow {

typedef Eigen::GpuDevice GPUDevice;

- template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
+ __global__ void DownCast(
+     const int size, const double* src, float* __restrict__ dst) {
+
+   GPU_1D_KERNEL_LOOP(index, size) {
+     dst[index] = (float)src[index];
+   }
+ }
+
+ template <typename Tin, typename Tindices, typename Tout,
+           bool ADJ_A, bool ADJ_B>
__global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
                                              int b_cols, int p,
                                              const Tindices* a_indices,
-                                             const T* a_values, const T* b,
-                                             T* out) {
+                                             const Tin* a_values, const Tin* b,
+                                             Tout* out) {
  // out_{ij} = sum_k {a_ik b_kj}
  // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk}
  const int n = (ADJ_B) ? b_cols : b_rows;
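
The hunk above introduces a DownCast kernel and splits the matmul kernel's scalar type into an input type Tin and an accumulator/output type Tout. GPU_1D_KERNEL_LOOP expands to a grid-stride loop, so the down-cast step can be sketched in plain CUDA as below; the kernel name, launch shape, and main() driver are illustrative only and not part of the patch.

#include <cstdio>
#include <cuda_runtime.h>

// Standalone sketch of the down-cast step: a grid-stride loop that truncates
// a double buffer into a float buffer, mirroring what DownCast does via
// GPU_1D_KERNEL_LOOP.
__global__ void DownCastSketch(int size, const double* src, float* dst) {
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < size;
       i += blockDim.x * gridDim.x) {
    dst[i] = static_cast<float>(src[i]);
  }
}

int main() {
  const int n = 1 << 20;
  double* src = nullptr;
  float* dst = nullptr;
  cudaMalloc(&src, n * sizeof(double));
  cudaMalloc(&dst, n * sizeof(float));
  cudaMemset(src, 0, n * sizeof(double));
  DownCastSketch<<<256, 256>>>(n, src, dst);
  cudaDeviceSynchronize();
  printf("down-cast %d values\n", n);
  cudaFree(src);
  cudaFree(dst);
  return 0;
}
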
@@ -44,31 +55,42 @@ __global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
      continue;  // Nowhere to signal an error :(
    }
    // out[i, j]
-   T* out_location = out + i * p + j;
+   Tout* out_location = out + i * p + j;
    if (!FastBoundsCheck(k, n)) {
-     GpuAtomicAdd(out_location, std::numeric_limits<T>::quiet_NaN());
+     GpuAtomicAdd(out_location, std::numeric_limits<Tout>::quiet_NaN());
      continue;
    }

    // a_value == (ADJ_A) ? a[k, i] : a[i, k]
-   const T a_value = ldg(a_values + a_ix);
+   const Tin a_value = ldg(a_values + a_ix);

    // b_value == (ADJ_B) ? b[j, k] : b[k, j]
-   const T b_value = ldg(b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
-   GpuAtomicAdd(out_location, a_value * b_value);
+   const Tin b_value = ldg(b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
+   GpuAtomicAdd(out_location, (Tout)(a_value * b_value));
  }
}

namespace functor {

- template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
- struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
+ bool RequireDeterminism() {
+   static bool require_determinism = [] {
+     bool deterministic_ops = false;
+     TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
+                                                /*default_val=*/false,
+                                                &deterministic_ops));
+     return deterministic_ops;
+   }();
+   return require_determinism;
+ }
+
+ template <typename Tindices, bool ADJ_A, bool ADJ_B>
+ struct SparseTensorDenseMatMulFunctor<GPUDevice, float, Tindices, ADJ_A, ADJ_B> {
  static EIGEN_ALWAYS_INLINE Status
- Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
+ Compute(OpKernelContext* context, typename TTypes<float>::Matrix out,
          typename TTypes<Tindices>::ConstMatrix a_indices,
-         typename TTypes<T>::ConstVec a_values,
-         typename TTypes<T>::ConstMatrix b) {
-   out.device(d) = out.constant(T(0));
+         typename TTypes<float>::ConstVec a_values,
+         typename TTypes<float>::ConstMatrix b) {
+   const GPUDevice d = context->eigen_device<GPUDevice>();
    int nnz = a_values.size();
    // out = A * B, A is [m x n] and B is [n x p], out is [m x p]
    int m = out.dimension(0);
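
The hunk above latches the TF_DETERMINISTIC_OPS environment variable once through a function-local static, and specializes the functor for float, with Compute now taking the OpKernelContext so it can allocate scratch memory. Below is a standalone sketch of the same read-the-env-var-once pattern, without TensorFlow's ReadBoolFromEnvVar (which accepts more spellings than the two checked here); the function name is made up for the sketch.

#include <cstdio>
#include <cstdlib>
#include <cstring>

// Read TF_DETERMINISTIC_OPS exactly once; the lambda runs only on the first
// call and the result is cached in a function-local static.
bool RequireDeterminismSketch() {
  static const bool value = [] {
    const char* v = std::getenv("TF_DETERMINISTIC_OPS");
    return v != nullptr &&
           (std::strcmp(v, "1") == 0 || std::strcmp(v, "true") == 0);
  }();
  return value;
}

int main() {
  std::printf("deterministic ops requested: %s\n",
              RequireDeterminismSketch() ? "yes" : "no");
  return 0;
}
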
@@ -80,12 +102,38 @@ struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
    // out.size()? Perhaps p * nnz ?
    GpuLaunchConfig config = GetGpuLaunchConfig(p * nnz, d);

-   TF_CHECK_OK(GpuLaunchKernel(
-       SparseTensorDenseMatMulKernel<T, Tindices, ADJ_A, ADJ_B>,
-       config.block_count, config.thread_per_block, 0, d.stream(), nnz, m,
-       b_rows, b_cols, p, a_indices.data(), a_values.data(), b.data(),
-       out.data()));
-
+   if (RequireDeterminism()) {
+     Tensor temp_buffer;
+     TensorShape outshape({m, p});
+
+     TF_RETURN_IF_ERROR(
+         context->allocate_temp(DT_DOUBLE, outshape, &temp_buffer));
+
+     TF_CHECK_OK(GpuLaunchKernel(
+         SetZero<double>, config.block_count, config.thread_per_block, 0,
+         d.stream(), m * p, temp_buffer.flat<double>().data()));
+
+     TF_CHECK_OK(GpuLaunchKernel(
+         SparseTensorDenseMatMulKernel<float, Tindices, double, ADJ_A, ADJ_B>,
+         config.block_count, config.thread_per_block, 0, d.stream(),
+         nnz, m, b_rows, b_cols, p, a_indices.data(), a_values.data(),
+         b.data(), temp_buffer.matrix<double>().data()));
+
+     TF_CHECK_OK(GpuLaunchKernel(
+         DownCast, config.block_count, config.thread_per_block,
+         0, d.stream(), m * p, temp_buffer.matrix<double>().data(),
+         out.data()));
+   } else {
+     TF_CHECK_OK(GpuLaunchKernel(
+         SetZero<float>, config.block_count, config.thread_per_block, 0,
+         d.stream(), m * p, out.data()));
+
+     TF_CHECK_OK(GpuLaunchKernel(
+         SparseTensorDenseMatMulKernel<float, Tindices, float, ADJ_A, ADJ_B>,
+         config.block_count, config.thread_per_block, 0, d.stream(), nnz, m,
+         b_rows, b_cols, p, a_indices.data(), a_values.data(), b.data(),
+         out.data()));
+   }
    return Status::OK();
  }
};
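
When determinism is requested, the functor above zeroes a DT_DOUBLE scratch buffer, accumulates the float products into it with atomic adds, and only then truncates to the float output via DownCast, the idea being that the wider accumulator absorbs most of the ordering-dependent rounding of the atomics. A toy plain-CUDA illustration of that accumulate-wide-then-truncate idea follows; the names, sizes, and managed-memory driver are made up for the sketch, and double atomicAdd needs a GPU of compute capability 6.0 or newer.

#include <cstdio>
#include <cuda_runtime.h>

// Scatter-add float values into a single double accumulator with atomics.
__global__ void AccumulateWide(int n, const float* vals, double* acc) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) atomicAdd(acc, static_cast<double>(vals[i]));
}

// Truncate the wide accumulator to float once, after all adds have landed.
__global__ void Truncate(const double* acc, float* out) {
  *out = static_cast<float>(*acc);
}

int main() {
  const int n = 1 << 16;
  float* vals = nullptr;
  double* acc = nullptr;
  float* out = nullptr;
  cudaMallocManaged(&vals, n * sizeof(float));
  cudaMallocManaged(&acc, sizeof(double));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < n; ++i) vals[i] = 1e-3f * static_cast<float>(i % 7);
  *acc = 0.0;
  AccumulateWide<<<(n + 255) / 256, 256>>>(n, vals, acc);
  Truncate<<<1, 1>>>(acc, out);
  cudaDeviceSynchronize();
  printf("sum = %f\n", *out);
  cudaFree(vals);
  cudaFree(acc);
  cudaFree(out);
  return 0;
}
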