
Commit 623b18d

wenscarl authored and nluehr committed
[determinism] Add deterministic tf.sparse.sparse_dense_matmul GPU kernel
1 parent a645904 commit 623b18d

8 files changed: +631 -448 lines changed
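
Not part of the commit itself, but a minimal usage sketch of the behavior this change adds, assuming a CUDA-capable GPU is available and that TF_DETERMINISTIC_OPS is set before the kernel first runs; the shapes and values below are illustrative only:

    import os
    os.environ["TF_DETERMINISTIC_OPS"] = "1"  # select the deterministic GPU path added here

    import tensorflow as tf

    # Small illustrative operands for tf.sparse.sparse_dense_matmul.
    sp_a = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],
                                  values=[1.0, 2.0],
                                  dense_shape=[2, 3])
    b = tf.random.normal([3, 4])

    with tf.device("/GPU:0"):  # assumes a GPU is present
        out = tf.sparse.sparse_dense_matmul(sp_a, b)

With the environment variable unset (or "0"), the original float32 atomic-accumulation kernel is used instead.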

tensorflow/core/kernels/sparse_tensor_dense_matmul_op.cc (+4 -3)

@@ -137,7 +137,7 @@ class SparseTensorDenseMatMulOp : public OpKernel {
     if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) {                        \
       Status functor_status = functor::SparseTensorDenseMatMulFunctor<       \
           Device, T, Tindices, ADJ_A,                                        \
-          ADJ_B>::Compute(ctx->eigen_device<Device>(), out->matrix<T>(),     \
+          ADJ_B>::Compute(ctx, out->matrix<T>(),                             \
                           a_indices->matrix<Tindices>(), a_values->vec<T>(), \
                           b->matrix<T>());                                   \
       OP_REQUIRES_OK(ctx, functor_status);                                   \
@@ -183,7 +183,7 @@ namespace functor {
   template <>                                                                \
   Status SparseTensorDenseMatMulFunctor<                                     \
       GPUDevice, T, Tindices, ADJ_A,                                         \
-      ADJ_B>::Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,    \
+      ADJ_B>::Compute(OpKernelContext* ctx, typename TTypes<T>::Matrix out,  \
                       TTypes<Tindices>::ConstMatrix a_indices,               \
                       typename TTypes<T>::ConstVec a_values,                 \
                       typename TTypes<T>::ConstMatrix b);                    \
@@ -246,10 +246,11 @@ struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> {
   // Vectorize certain operations above this size.
   static const std::size_t kNumVectorize = 32;
 
-  static Status Compute(const CPUDevice& d, typename TTypes<T>::Matrix out,
+  static Status Compute(OpKernelContext* context, typename TTypes<T>::Matrix out,
                         typename TTypes<Tindices>::ConstMatrix a_indices,
                         typename TTypes<T>::ConstVec a_values,
                         typename TTypes<T>::ConstMatrix b) {
+    const CPUDevice d = context->eigen_device<CPUDevice>();
     const std::size_t nnz = a_values.size();
     const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1));
     const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0));

tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h (+2 -1)

@@ -17,6 +17,7 @@ limitations under the License.
 #define TENSORFLOW_CORE_KERNELS_SPARSE_TENSOR_DENSE_MATMUL_OP_H_
 
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/tensor_types.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/lib/core/errors.h"
@@ -29,7 +30,7 @@ template <typename Device, typename T, typename Tindices, bool ADJ_A,
           bool ADJ_B>
 struct SparseTensorDenseMatMulFunctor {
   static EIGEN_ALWAYS_INLINE Status Compute(
-      const Device& d, typename TTypes<T>::Matrix out,
+      OpKernelContext* context, typename TTypes<T>::Matrix out,
       typename TTypes<Tindices>::ConstMatrix a_indices,
       typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b);
 };

tensorflow/core/kernels/sparse_tensor_dense_matmul_op_gpu.cu.cc (+68 -20)

@@ -20,18 +20,29 @@ limitations under the License.
 #include "tensorflow/core/framework/bounds_check.h"
 #include "tensorflow/core/framework/register_types.h"
 #include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h"
+#include "tensorflow/core/util/env_var.h"
+#include "tensorflow/core/util/gpu_device_functions.h"
 #include "tensorflow/core/util/gpu_kernel_helper.h"
 
 namespace tensorflow {
 
 typedef Eigen::GpuDevice GPUDevice;
 
-template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
+__global__ void DownCast(
+    const int size, const double* src, float* __restrict__ dst) {
+
+  GPU_1D_KERNEL_LOOP(index, size) {
+    dst[index] = (float)src[index];
+  }
+}
+
+template <typename Tin, typename Tindices, typename Tout,
+          bool ADJ_A, bool ADJ_B>
 __global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
                                               int b_cols, int p,
                                               const Tindices* a_indices,
-                                              const T* a_values, const T* b,
-                                              T* out) {
+                                              const Tin* a_values, const Tin* b,
+                                              Tout* out) {
   // out_{ij} = sum_k {a_ik b_kj}
   // out = A * B', out_{ij} = sum_k {a_ik (b')_kj}; b'_{kj} = b_{jk}
   const int n = (ADJ_B) ? b_cols : b_rows;
@@ -44,31 +55,42 @@ __global__ void SparseTensorDenseMatMulKernel(int nnz, int m, int b_rows,
       continue;  // Nowhere to signal an error :(
     }
     // out[i, j]
-    T* out_location = out + i * p + j;
+    Tout* out_location = out + i * p + j;
     if (!FastBoundsCheck(k, n)) {
-      GpuAtomicAdd(out_location, std::numeric_limits<T>::quiet_NaN());
+      GpuAtomicAdd(out_location, std::numeric_limits<Tout>::quiet_NaN());
       continue;
     }
 
     // a_value == (ADJ_A) ? a[k, i] : a[i, k]
-    const T a_value = ldg(a_values + a_ix);
+    const Tin a_value = ldg(a_values + a_ix);
 
     // b_value == (ADJ_B) ? b[j, k] : b[k, j]
-    const T b_value = ldg(b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
-    GpuAtomicAdd(out_location, a_value * b_value);
+    const Tin b_value = ldg(b + ((ADJ_B) ? j * b_cols + k : k * b_cols + j));
+    GpuAtomicAdd(out_location, (Tout)(a_value * b_value));
   }
 }
 
 namespace functor {
 
-template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B>
-struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
+bool RequireDeterminism() {
+  static bool require_determinism = [] {
+    bool deterministic_ops = false;
+    TF_CHECK_OK(tensorflow::ReadBoolFromEnvVar("TF_DETERMINISTIC_OPS",
+                                               /*default_val=*/false,
+                                               &deterministic_ops));
+    return deterministic_ops;
+  }();
+  return require_determinism;
+}
+
+template <typename Tindices, bool ADJ_A, bool ADJ_B>
+struct SparseTensorDenseMatMulFunctor<GPUDevice, float, Tindices, ADJ_A, ADJ_B> {
   static EIGEN_ALWAYS_INLINE Status
-  Compute(const GPUDevice& d, typename TTypes<T>::Matrix out,
+  Compute(OpKernelContext* context, typename TTypes<float>::Matrix out,
           typename TTypes<Tindices>::ConstMatrix a_indices,
-          typename TTypes<T>::ConstVec a_values,
-          typename TTypes<T>::ConstMatrix b) {
-    out.device(d) = out.constant(T(0));
+          typename TTypes<float>::ConstVec a_values,
+          typename TTypes<float>::ConstMatrix b) {
+    const GPUDevice d = context->eigen_device<GPUDevice>();
     int nnz = a_values.size();
     // out = A * B, A is [m x n] and B is [n x p], out is [m x p]
     int m = out.dimension(0);
@@ -80,12 +102,38 @@ struct SparseTensorDenseMatMulFunctor<GPUDevice, T, Tindices, ADJ_A, ADJ_B> {
     // out.size()? Perhaps p * nnz ?
     GpuLaunchConfig config = GetGpuLaunchConfig(p * nnz, d);
 
-    TF_CHECK_OK(GpuLaunchKernel(
-        SparseTensorDenseMatMulKernel<T, Tindices, ADJ_A, ADJ_B>,
-        config.block_count, config.thread_per_block, 0, d.stream(), nnz, m,
-        b_rows, b_cols, p, a_indices.data(), a_values.data(), b.data(),
-        out.data()));
-
+    if (RequireDeterminism()) {
+      Tensor temp_buffer;
+      TensorShape outshape({m, p});
+
+      TF_RETURN_IF_ERROR(
+          context, context->allocate_temp(DT_DOUBLE, outshape, &temp_buffer));
+
+      TF_CHECK_OK(GpuLaunchKernel(
+          SetZero<double>, config.block_count, config.thread_per_block, 0,
+          d.stream(), m * p, (&temp_buffer)->flat<double>().data()));
+
+      TF_CHECK_OK(GpuLaunchKernel(
+          SparseTensorDenseMatMulKernel<float, Tindices, double, ADJ_A, ADJ_B>,
+          config.block_count, config.thread_per_block, 0, d.stream(),
+          nnz, m, b_rows, b_cols, p, a_indices.data(), a_values.data(),
+          b.data(), ((&temp_buffer)->matrix<double>()).data()));
+
+      TF_CHECK_OK(GpuLaunchKernel(
+          DownCast, config.block_count, config.thread_per_block,
+          0, d.stream(), m * p, ((&temp_buffer)->matrix<double>()).data(),
+          out.data()));
+    } else {
+      TF_CHECK_OK(GpuLaunchKernel(
+          SetZero<float>, config.block_count, config.thread_per_block, 0,
+          d.stream(), m * p, out.data()));
+
+      TF_CHECK_OK(GpuLaunchKernel(
+          SparseTensorDenseMatMulKernel<float, Tindices, float, ADJ_A, ADJ_B>,
+          config.block_count, config.thread_per_block, 0, d.stream(), nnz, m,
+          b_rows, b_cols, p, a_indices.data(), a_values.data(), b.data(),
+          out.data()));
+    }
     return Status::OK();
   }
 };
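
When TF_DETERMINISTIC_OPS is enabled, the branch above allocates a DT_DOUBLE temp_buffer, zeroes it, accumulates the float32 products into it with GpuAtomicAdd at double precision, and only then narrows back to float32 with the new DownCast kernel, the aim being that repeated runs of the op produce identical float32 results. A minimal NumPy sketch of that accumulation strategy for the non-adjoint case (the function name and shapes are illustrative, not part of the commit):

    import numpy as np

    def sparse_dense_matmul_double_accum(a_indices, a_values, b, m):
        # a_indices: [nnz, 2] ints, a_values: [nnz] float32, b: [n, p] float32.
        temp = np.zeros((m, b.shape[1]), dtype=np.float64)  # the DT_DOUBLE temp_buffer
        for (i, k), v in zip(a_indices, a_values):
            prod = (v * b[k, :]).astype(np.float64)  # float32 product, widened like (Tout)(a_value * b_value)
            temp[i, :] += prod                       # accumulated at double precision
        return temp.astype(np.float32)               # what the DownCast kernel does at the end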

tensorflow/python/kernel_tests/BUILD (+23 -5)

@@ -3267,11 +3267,10 @@ cuda_py_test(
     xla_enable_strict_auto_jit = True,
 )
 
-cuda_py_test(
-    name = "sparse_tensor_dense_matmul_op_test",
-    size = "medium",
-    srcs = ["sparse_tensor_dense_matmul_op_test.py"],
-    additional_deps = [
+py_library(
+    name = "sparse_tensor_dense_matmul_op_base",
+    srcs = ["sparse_tensor_dense_matmul_op_base.py"],
+    deps = [
         "//third_party/py/numpy",
         "//tensorflow/core:protos_all_py",
         "//tensorflow/python:array_ops",
@@ -3284,6 +3283,25 @@ cuda_py_test(
         "//tensorflow/python:platform",
         "//tensorflow/python:sparse_ops",
     ],
+)
+
+cuda_py_test(
+    name = "sparse_tensor_dense_matmul_op_test",
+    size = "medium",
+    srcs = ["sparse_tensor_dense_matmul_op_test.py"],
+    additional_deps = [
+        ":sparse_tensor_dense_matmul_op_base",
+    ],
+    xla_enable_strict_auto_jit = True,
+)
+
+cuda_py_test(
+    name = "sparse_tensor_dense_matmul_op_deterministic_test",
+    size = "small",
+    srcs = ["sparse_tensor_dense_matmul_op_deterministic_test.py"],
+    additional_deps = [
+        ":sparse_tensor_dense_matmul_op_base",
+    ],
     xla_enable_strict_auto_jit = True,
 )
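
The new sparse_tensor_dense_matmul_op_deterministic_test target points at sparse_tensor_dense_matmul_op_deterministic_test.py, whose contents are not part of this diff. The sketch below is therefore only an assumption about the kind of check such a test could perform (names, shapes, and the repetition count are hypothetical): run the op repeatedly with TF_DETERMINISTIC_OPS=1 and require bitwise-identical results.

    import os
    os.environ["TF_DETERMINISTIC_OPS"] = "1"  # must be set before the kernel first runs

    import numpy as np
    import tensorflow as tf

    def run_once(seed):
        tf.random.set_seed(seed)
        indices = [[i, (i * 7) % 32] for i in range(100)]  # arbitrary sparsity pattern
        values = tf.random.normal([100])
        sp_a = tf.sparse.SparseTensor(indices, values, dense_shape=[100, 32])
        b = tf.random.normal([32, 16])
        with tf.device("/GPU:0"):  # assumes a GPU is available
            return tf.sparse.sparse_dense_matmul(sp_a, b).numpy()

    first = run_once(seed=123)
    for _ in range(5):
        np.testing.assert_array_equal(first, run_once(seed=123))  # bitwise equality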
