@@ -28,6 +28,9 @@ namespace tensorflow {
2828
2929typedef Eigen::ThreadPoolDevice CPUDevice;
3030typedef Eigen::GpuDevice GPUDevice;
31+ #ifdef TENSORFLOW_USE_SYCL
32+ typedef Eigen::SyclDevice SYCLDevice;
33+ #endif // TENSORFLOW_USE_SYCL
3134
3235template <typename Device, typename T>
3336class BatchNormOp : public OpKernel {
@@ -207,6 +210,18 @@ TF_CALL_float(REGISTER_GPU_KERNEL);
207210
208211#endif // GOOGLE_CUDA
209212
// SYCL registration of the forward BatchNormWithGlobalNormalization kernel.
// Guarded with #ifdef (not #if) so it matches the guard around the SYCLDevice
// typedef and also works when TENSORFLOW_USE_SYCL is defined without a value.
#ifdef TENSORFLOW_USE_SYCL
// Function-like helper: there must be no whitespace between the macro name and
// its parameter list, otherwise the preprocessor treats it as object-like.
#define REGISTER_KERNEL(T)                                         \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalization") \
                              .Device(DEVICE_SYCL)                 \
                              .TypeConstraint<T>("T"),             \
                          BatchNormOp<SYCLDevice, T>);

// Instantiate the registration for the float and double element types.
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
// Release the helper name; later registration blocks redefine it.
#undef REGISTER_KERNEL
#endif  // TENSORFLOW_USE_SYCL
224+
210225#define REGISTER_KERNEL (T ) \
211226 REGISTER_KERNEL_BUILDER (Name(" BatchNormWithGlobalNormalizationGrad" ) \
212227 .Device(DEVICE_CPU) \
@@ -254,4 +269,17 @@ TF_CALL_float(REGISTER_GPU_KERNEL);
254269
255270#endif // GOOGLE_CUDA
256271
// SYCL registration of the BatchNormWithGlobalNormalizationGrad kernel.
// Guarded with #ifdef (not #if) so it matches the guard around the SYCLDevice
// typedef and also works when TENSORFLOW_USE_SYCL is defined without a value.
#ifdef TENSORFLOW_USE_SYCL
// Function-like helper: there must be no whitespace between the macro name and
// its parameter list, otherwise the preprocessor treats it as object-like.
#define REGISTER_KERNEL(T)                                             \
  REGISTER_KERNEL_BUILDER(Name("BatchNormWithGlobalNormalizationGrad") \
                              .Device(DEVICE_SYCL)                     \
                              .TypeConstraint<T>("T"),                 \
                          BatchNormGradOp<SYCLDevice, T>);

// Instantiate the registration for the float and double element types.
TF_CALL_float(REGISTER_KERNEL);
TF_CALL_double(REGISTER_KERNEL);
// Keep the helper macro from leaking past this registration block.
#undef REGISTER_KERNEL

#endif  // TENSORFLOW_USE_SYCL
284+
257285} // namespace tensorflow
0 commit comments