Skip to content

Commit ca12ef3

Browse files
ville-k authored and benoitsteiner committed
Register batch normalization kernels for OpenCL (#61)
1 parent f113ef0 commit ca12ef3

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

tensorflow/core/kernels/batch_norm_op.cc

Lines changed: 28 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -28,6 +28,9 @@ namespace tensorflow {
2828

2929
typedef Eigen::ThreadPoolDevice CPUDevice;
3030
typedef Eigen::GpuDevice GPUDevice;
31+
#ifdef TENSORFLOW_USE_SYCL
32+
typedef Eigen::SyclDevice SYCLDevice;
33+
#endif // TENSORFLOW_USE_SYCL
3134

3235
template <typename Device, typename T>
3336
class BatchNormOp : public OpKernel {
@@ -207,6 +210,18 @@ TF_CALL_float(REGISTER_GPU_KERNEL);
207210

208211
#endif // GOOGLE_CUDA
209212

213+
#if TENSORFLOW_USE_SYCL
214+
#define REGISTER_KERNEL(T) \
215+
REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
216+
.Device(DEVICE_SYCL) \
217+
.TypeConstraint<T>("T"), \
218+
BatchNormOp<SYCLDevice, T>);
219+
220+
TF_CALL_float(REGISTER_KERNEL);
221+
TF_CALL_double(REGISTER_KERNEL);
222+
#undef REGISTER_KERNEL
223+
#endif // TENSORFLOW_USE_SYCL
224+
210225
#define REGISTER_KERNEL(T) \
211226
REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
212227
.Device(DEVICE_CPU) \
@@ -254,4 +269,17 @@ TF_CALL_float(REGISTER_GPU_KERNEL);
254269

255270
#endif // GOOGLE_CUDA
256271

272+
#if TENSORFLOW_USE_SYCL
273+
#define REGISTER_KERNEL(T) \
274+
REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
275+
.Device(DEVICE_SYCL) \
276+
.TypeConstraint<T>("T"), \
277+
BatchNormGradOp<SYCLDevice, T>);
278+
279+
TF_CALL_float(REGISTER_KERNEL);
280+
TF_CALL_double(REGISTER_KERNEL);
281+
#undef REGISTER_KERNEL
282+
283+
#endif // TENSORFLOW_USE_SYCL
284+
257285
} // namespace tensorflow

0 commit comments

Comments
 (0)