@@ -20,18 +20,29 @@ limitations under the License.
2020#include " tensorflow/core/framework/bounds_check.h"
2121#include " tensorflow/core/framework/register_types.h"
2222#include " tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"
23+ #include " tensorflow/core/util/env_var.h"
24+ #include " tensorflow/core/util/gpu_device_functions.h"
2325#include " tensorflow/core/util/gpu_kernel_helper.h"
2426
2527namespace tensorflow {
2628
2729typedef Eigen::GpuDevice GPUDevice;
2830
29- template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
31+ __global__ void DownCast (
32+ const int size, const double * src, float * __restrict__ dst) {
33+
34+ GPU_1D_KERNEL_LOOP (index, size) {
35+ dst[index] = (float )src[index];
36+ }
37+ }
38+
39+ template <typename Tin, typename Tindices, typename Tout,
40+ bool ADJ_A, bool ADJ_B>
3041__global__ void SparseTensorDenseMatMulKernel (int nnz, int m, int b_rows,
3142 int b_cols, int p,
3243 const Tindices* a_indices,
33- const T * a_values, const T * b,
34- T * out) {
44+ const Tin * a_values, const Tin * b,
45+ Tout * out) {
3546 // out_{ij} = sum_k {a_ik b_kj}
3647 // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk}
3748 const int n = (ADJ_B) ? b_cols : b_rows;
@@ -44,31 +55,42 @@ __global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
4455 continue ; // Nowhere to signal an error :(
4556 }
4657 // out[i, j]
47- T * out_location = out + i * p + j;
58+ Tout * out_location = out + i * p + j;
4859 if (!FastBoundsCheck (k, n)) {
49- GpuAtomicAdd (out_location, std::numeric_limits<T >::quiet_NaN ());
60+ GpuAtomicAdd (out_location, std::numeric_limits<Tout >::quiet_NaN ());
5061 continue ;
5162 }
5263
5364 // a_value == (ADJ_A) ? a[k, i] : a[i, k]
54- const T a_value = ldg (a_values + a_ix);
65+ const Tin a_value = ldg (a_values + a_ix);
5566
5667 // b_value == (ADJ_B) ? b[j, k] : b[k, j]
57- const T b_value = ldg (b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
58- GpuAtomicAdd (out_location, a_value * b_value);
68+ const Tin b_value = ldg (b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
69+ GpuAtomicAdd (out_location, (Tout)( a_value * b_value) );
5970 }
6071}
6172
6273namespace functor {
6374
64- template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
65- struct SparseTensorDenseMatMulFunctor <GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
75+ bool RequireDeterminism () {
76+ static bool require_determinism = [] {
77+ bool deterministic_ops = false ;
78+ TF_CHECK_OK (tensorflow::ReadBoolFromEnvVar (" TF_DETERMINISTIC_OPS" ,
79+ /* default_val=*/ false ,
80+ &deterministic_ops));
81+ return deterministic_ops;
82+ }();
83+ return require_determinism;
84+ }
85+
86+ template <typename Tindices, bool ADJ_A, bool ADJ_B>
87+ struct SparseTensorDenseMatMulFunctor <GPUDevice, float , Tindices, ADJ_A, ADJ_B> {
6688 static EIGEN_ALWAYS_INLINE Status
67- Compute (const GPUDevice& d , typename TTypes<T >::Matrix out,
89+ Compute (OpKernelContext* context , typename TTypes<float >::Matrix out,
6890 typename TTypes<Tindices>::ConstMatrix a_indices,
69- typename TTypes<T >::ConstVec a_values,
70- typename TTypes<T >::ConstMatrix b) {
71- out. device (d) = out. constant ( T ( 0 ) );
91+ typename TTypes<float >::ConstVec a_values,
92+ typename TTypes<float >::ConstMatrix b) {
93+ const GPUDevice d = context-> eigen_device <GPUDevice>( );
7294 int nnz = a_values.size ();
7395 // out = A * B, A is [m x n] and B is [n x p], out is [m x p]
7496 int m = out.dimension (0 );
@@ -80,12 +102,38 @@ struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
80102 // out.size()? Perhaps p * nnz ?
81103 GpuLaunchConfig config = GetGpuLaunchConfig (p * nnz, d);
82104
83- TF_CHECK_OK (GpuLaunchKernel (
84- SparseTensorDenseMatMulKernel<T, Tindices, ADJ_A, ADJ_B>,
85- config.block_count , config.thread_per_block , 0 , d.stream (), nnz, m,
86- b_rows, b_cols, p, a_indices.data (), a_values.data (), b.data (),
87- out.data ()));
88-
105+ if (RequireDeterminism ()) {
106+ Tensor temp_buffer;
107+ TensorShape outshape ({m, p});
108+
109+ TF_RETURN_IF_ERROR (
110+ context, context->allocate_temp (DT_DOUBLE, outshape, &temp_buffer));
111+
112+ TF_CHECK_OK (GpuLaunchKernel (
113+ SetZero<double >, config.block_count , config.thread_per_block , 0 ,
114+ d.stream (), m * p, (&temp_buffer)->flat <double >().data ()));
115+
116+ TF_CHECK_OK (GpuLaunchKernel (
117+ SparseTensorDenseMatMulKernel<float , Tindices, double , ADJ_A, ADJ_B>,
118+ config.block_count , config.thread_per_block , 0 , d.stream (),
119+ nnz, m, b_rows, b_cols, p, a_indices.data (), a_values.data (),
120+ b.data (), ((&temp_buffer)->matrix <double >()).data ()));
121+
122+ TF_CHECK_OK (GpuLaunchKernel (
123+ DownCast, config.block_count , config.thread_per_block ,
124+ 0 , d.stream (), m * p, ((&temp_buffer)->matrix <double >()).data (),
125+ out.data ()));
126+ } else {
127+ TF_CHECK_OK (GpuLaunchKernel (
128+ SetZero<float >, config.block_count , config.thread_per_block , 0 ,
129+ d.stream (), m * p, out.data ()));
130+
131+ TF_CHECK_OK (GpuLaunchKernel (
132+ SparseTensorDenseMatMulKernel<float , Tindices, float , ADJ_A, ADJ_B>,
133+ config.block_count , config.thread_per_block , 0 , d.stream (), nnz, m,
134+ b_rows, b_cols, p, a_indices.data (), a_values.data (), b.data (),
135+ out.data ()));
136+ }
89137 return Status::OK ();
90138 }
91139};
0 commit comments