@@ -38,6 +38,9 @@ limitations under the License.
3838namespace tensorflow {
3939
4040typedef Eigen::ThreadPoolDevice CPUDevice;
41+ #ifdef TENSORFLOW_USE_SYCL
42+ typedef Eigen::SyclDevice SYCLDevice;
43+ #endif // TENSORFLOW_USE_SYCL
4144
4245class OpKernelContext ;
4346
@@ -186,6 +189,91 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_MATH)
186189#undef REGISTER_SCATTER_ND_INDEX
187190#undef REGISTER_SCATTER_ND_FULL
188191
192+ #ifdef TENSORFLOW_USE_SYCL
193+
194+ // Implementation of update functor for SYCL.
195+ template <typename T, typename Index, scatter_nd_op::UpdateOp OP, int IXDIM>
196+ struct ScatterNdFunctor <SYCLDevice, T, Index, OP, IXDIM> {
197+ Index operator ()(
198+ const SYCLDevice& d, const Index slice_size,
199+ const Eigen::array<Eigen::DenseIndex, IXDIM> output_shape_prefix,
200+ typename TTypes<T, 2 >::Tensor Tparams,
201+ typename TTypes<Index, 2 >::ConstTensor Tindices,
202+ typename TTypes<T, 2 >::ConstTensor Tupdates,
203+ typename TTypes<T, 2 >::Tensor Toutput) {
204+ // error_loc is -1 if there's no out-of-bounds index,
205+ // otherwise it is the location of an OOB index in Tindices.
206+ Index error_loc = -1 ;
207+
208+ const Eigen::DenseIndex batch_size = Tindices.dimension (0 );
209+
210+ Index batch_strides[IXDIM];
211+ for (int dim = IXDIM - 1 ; dim >= 0 ; --dim) {
212+ if (dim == IXDIM - 1 ) {
213+ batch_strides[dim] = 1 ;
214+ } else {
215+ batch_strides[dim] =
216+ batch_strides[dim + 1 ] * output_shape_prefix[dim + 1 ];
217+ }
218+ }
219+
220+ for (Eigen::DenseIndex loc = 0 ; loc < batch_size; ++loc) {
221+ Index i = 0 ;
222+ bool out_of_bounds = false ;
223+ for (int dim = 0 ; dim < IXDIM; ++dim) {
224+ const Index ix_d = internal::SubtleMustCopy (Tindices (loc, dim));
225+ out_of_bounds |= !FastBoundsCheck (ix_d, output_shape_prefix[dim]);
226+ i += ix_d * batch_strides[dim];
227+ }
228+ if (TF_PREDICT_FALSE (out_of_bounds)) {
229+ error_loc = loc;
230+ break ;
231+ } else {
232+ auto input_chip = Toutput.template chip <0 >(i);
233+ auto output_chip = input_chip.device (d);
234+ auto update_chip = Tupdates.template chip <0 >(loc);
235+ update_executor::UpdateExecutor<
236+ decltype (input_chip), decltype (update_chip), decltype (output_chip),
237+ OP>::Execute (input_chip, update_chip, output_chip);
238+ }
239+ }
240+
241+ return error_loc;
242+ }
243+ };
244+
245+ #define REGISTER_SCATTER_ND_FULL_SYCL (T, Index, op ) \
246+ template Index \
247+ ScatterNdFunctor<SYCLDevice, T, Index, op, CPU_PROVIDED_IXDIM>::operator ()( \
248+ const SYCLDevice& d, const Index slice_size, \
249+ const Eigen::array<Eigen::DenseIndex, CPU_PROVIDED_IXDIM> \
250+ output_shape_prefix, \
251+ typename TTypes<T, 2 >::Tensor Tparams, \
252+ typename TTypes<Index, 2 >::ConstTensor Tindices, \
253+ typename TTypes<T, 2 >::ConstTensor Tupdates, \
254+ typename TTypes<T, 2 >::Tensor Toutput)
255+
256+ #define REGISTER_SCATTER_ND_INDEX_SYCL (type, op ) \
257+ REGISTER_SCATTER_ND_FULL_SYCL (type, int32, op); \
258+ REGISTER_SCATTER_ND_FULL_SYCL (type, int64, op)
259+
260+ #define REGISTER_SCATTER_ND_UPDATE_SYCL (type ) \
261+ REGISTER_SCATTER_ND_INDEX_SYCL (type, scatter_nd_op::UpdateOp::ASSIGN);
262+
263+ #define REGISTER_SCATTER_ND_MATH_SYCL (type ) \
264+ REGISTER_SCATTER_ND_INDEX_SYCL (type, scatter_nd_op::UpdateOp::ADD); \
265+ REGISTER_SCATTER_ND_INDEX_SYCL (type, scatter_nd_op::UpdateOp::SUB);
266+
267+ TF_CALL_GPU_NUMBER_TYPES_NO_HALF (REGISTER_SCATTER_ND_UPDATE_SYCL)
268+ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_MATH_SYCL)
269+
270+ #undef REGISTER_SCATTER_ND_MATH_SYCL
271+ #undef REGISTER_SCATTER_ND_UPDATE_SYCL
272+ #undef REGISTER_SCATTER_ND_INDEX_SYCL
273+ #undef REGISTER_SCATTER_ND_FULL_SYCL
274+
275+ #endif // TENSORFLOW_USE_SYCL
276+
189277} // namespace functor
190278
191279} // namespace tensorflow
0 commit comments