@@ -77,12 +77,13 @@ void collapseDims(TensorInfo<T, T2>& info, Args&... infos) {
7777 collapseDims (infos...);
7878}
7979
80- #define DEVICE_LINEAR_GET (D_TENSOR, INDEX ) \
81- D_TENSOR.data[IndexToOffset<scalar_t , index_type>::get(INDEX, D_TENSOR)]
80+ #define DEVICE_LINEAR_GET (D_TENSOR, INDEX ) \
81+ D_TENSOR.data[IndexToOffset<scalar_t , index_type, indexing_kind>::get( \
82+ INDEX, D_TENSOR)]
8283
8384// Biases are always 1D
8485#define DEVICE_BIAS_GET (D_TENSOR, INDEX ) \
85- D_TENSOR.data[IndexToOffset<scalar_t , index_type>::get(INDEX, D_TENSOR)]
86+ D_TENSOR.data[IndexToOffset<scalar_t , index_type, 1 >::get(INDEX, D_TENSOR)]
8687
8788#define H2F (input ) static_cast <accscalar_t >(input)
8889#define F2H (input ) static_cast <scalar_t >(input)
@@ -93,7 +94,11 @@ inline T sigmoid(T in) {
9394 return one / (one + std::exp (-in));
9495}
9596
96- template <typename scalar_t , typename accscalar_t , typename index_type>
97+ template <
98+ typename scalar_t ,
99+ typename accscalar_t ,
100+ typename index_type,
101+ int indexing_kind>
97102struct LstmCellForwardFunctor {
98103 void operator ()(sycl::nd_item<1 > item) const {
99104 bool has_bias = bias1_.data != nullptr ;
@@ -205,7 +210,11 @@ struct LstmCellForwardFunctor {
205210 index_type totalElements_;
206211};
207212
208- template <typename scalar_t , typename accscalar_t , typename index_type>
213+ template <
214+ typename scalar_t ,
215+ typename accscalar_t ,
216+ typename index_type,
217+ int indexing_kind>
209218struct LstmCellBackwardFunctor {
210219 void operator ()(sycl::nd_item<1 > item) const {
211220 bool has_gradoutput = gradoutput_.data != nullptr ;
@@ -296,7 +305,11 @@ struct LstmCellBackwardFunctor {
296305 index_type totalElements_;
297306};
298307
299- template <typename scalar_t , typename accscalar_t , typename index_type>
308+ template <
309+ typename scalar_t ,
310+ typename accscalar_t ,
311+ typename index_type,
312+ int indexing_kind>
300313struct GruCellForwardFunctor {
301314 void operator ()(sycl::nd_item<1 > item) const {
302315 bool has_bias = Bias1_.data != nullptr ;
@@ -387,7 +400,11 @@ struct GruCellForwardFunctor {
387400 const index_type totalElements_;
388401};
389402
390- template <typename scalar_t , typename accscalar_t , typename index_type>
403+ template <
404+ typename scalar_t ,
405+ typename accscalar_t ,
406+ typename index_type,
407+ int indexing_kind>
391408struct GruCellBackwardFunctor {
392409 void operator ()(sycl::nd_item<1 > item) const {
393410 for (index_type linearIndex = item.get_global_id (0 );
@@ -469,12 +486,6 @@ void lstm_forward_impl(
469486 if (numel == 0 )
470487 return ;
471488
472- using KernelT = LstmCellForwardFunctor<scalar_t , accscalar_t , index_type>;
473- auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
474- auto config = rnn_get_launch_config (max_wg_size, numel);
475- auto nwg = std::get<0 >(config);
476- auto local_range = std::get<1 >(config);
477-
478489 auto input_gatesI = getTensorInfo<scalar_t , index_type>(input_gates);
479490 auto hidden_gatesI = getTensorInfo<scalar_t , index_type>(hidden_gates);
480491 auto input_biasI = tryGetTensorInfo<scalar_t , index_type>(input_bias);
@@ -503,6 +514,12 @@ void lstm_forward_impl(
503514 hyI,
504515 cyI,
505516 workspaceI);
517+ using KernelT =
518+ LstmCellForwardFunctor<scalar_t , accscalar_t , index_type, 1 >;
519+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
520+ auto config = rnn_get_launch_config (max_wg_size, numel);
521+ auto nwg = std::get<0 >(config);
522+ auto local_range = std::get<1 >(config);
506523 KernelT kfn (
507524 input_gatesI,
508525 hidden_gatesI,
@@ -517,6 +534,12 @@ void lstm_forward_impl(
517534 sycl_kernel_submit (
518535 nwg * local_range, local_range, getCurrentSYCLQueue (), kfn);
519536 } else {
537+ using KernelT =
538+ LstmCellForwardFunctor<scalar_t , accscalar_t , index_type, 2 >;
539+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
540+ auto config = rnn_get_launch_config (max_wg_size, numel);
541+ auto nwg = std::get<0 >(config);
542+ auto local_range = std::get<1 >(config);
520543 KernelT kfn (
521544 input_gatesI,
522545 hidden_gatesI,
@@ -548,12 +571,6 @@ void lstm_backward_impl(
548571 if (numel == 0 )
549572 return ;
550573
551- using KernelT = LstmCellBackwardFunctor<scalar_t , accscalar_t , index_type>;
552- auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
553- auto config = rnn_get_launch_config (max_wg_size, numel);
554- auto nwg = std::get<0 >(config);
555- auto local_range = std::get<1 >(config);
556-
557574 auto grad_hyI = tryGetTensorInfo<scalar_t , index_type>(grad_hy);
558575 auto grad_cyI = tryGetTensorInfo<scalar_t , index_type>(grad_cy);
559576 auto cxI = getTensorInfo<scalar_t , index_type>(cx);
@@ -567,6 +584,12 @@ void lstm_backward_impl(
567584 {grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx})) {
568585 collapseDims (
569586 grad_hyI, grad_cyI, cxI, cyI, workspaceI, grad_gatesI, grad_cxI);
587+ using KernelT =
588+ LstmCellBackwardFunctor<scalar_t , accscalar_t , index_type, 1 >;
589+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
590+ auto config = rnn_get_launch_config (max_wg_size, numel);
591+ auto nwg = std::get<0 >(config);
592+ auto local_range = std::get<1 >(config);
570593 KernelT kfn (
571594 workspaceI,
572595 grad_gatesI,
@@ -580,6 +603,12 @@ void lstm_backward_impl(
580603 sycl_kernel_submit (
581604 nwg * local_range, local_range, getCurrentSYCLQueue (), kfn);
582605 } else {
606+ using KernelT =
607+ LstmCellBackwardFunctor<scalar_t , accscalar_t , index_type, 2 >;
608+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
609+ auto config = rnn_get_launch_config (max_wg_size, numel);
610+ auto nwg = std::get<0 >(config);
611+ auto local_range = std::get<1 >(config);
583612 KernelT kfn (
584613 workspaceI,
585614 grad_gatesI,
@@ -610,12 +639,6 @@ void gru_forward_impl(
610639 if (numel == 0 )
611640 return ;
612641
613- using KernelT = GruCellForwardFunctor<scalar_t , accscalar_t , index_type>;
614- auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
615- auto config = rnn_get_launch_config (max_wg_size, numel);
616- auto nwg = std::get<0 >(config);
617- auto local_range = std::get<1 >(config);
618-
619642 auto input_gatesI = getTensorInfo<scalar_t , index_type>(input_gates);
620643 auto hidden_gatesI = getTensorInfo<scalar_t , index_type>(hidden_gates);
621644 auto input_biasI = tryGetTensorInfo<scalar_t , index_type>(input_bias);
@@ -641,6 +664,11 @@ void gru_forward_impl(
641664 hxI,
642665 hyI,
643666 workspaceI);
667+ using KernelT = GruCellForwardFunctor<scalar_t , accscalar_t , index_type, 1 >;
668+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
669+ auto config = rnn_get_launch_config (max_wg_size, numel);
670+ auto nwg = std::get<0 >(config);
671+ auto local_range = std::get<1 >(config);
644672 KernelT kfn (
645673 input_gatesI,
646674 hidden_gatesI,
@@ -654,6 +682,11 @@ void gru_forward_impl(
654682 sycl_kernel_submit (
655683 nwg * local_range, local_range, getCurrentSYCLQueue (), kfn);
656684 } else {
685+ using KernelT = GruCellForwardFunctor<scalar_t , accscalar_t , index_type, 2 >;
686+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
687+ auto config = rnn_get_launch_config (max_wg_size, numel);
688+ auto nwg = std::get<0 >(config);
689+ auto local_range = std::get<1 >(config);
657690 KernelT kfn (
658691 input_gatesI,
659692 hidden_gatesI,
@@ -682,12 +715,6 @@ void gru_backward_impl(
682715 if (numel == 0 )
683716 return ;
684717
685- using KernelT = GruCellBackwardFunctor<scalar_t , accscalar_t , index_type>;
686- auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
687- auto config = rnn_get_launch_config (max_wg_size, numel);
688- auto nwg = std::get<0 >(config);
689- auto local_range = std::get<1 >(config);
690-
691718 auto grad_hyI = getTensorInfo<scalar_t , index_type>(grad_hy);
692719 auto workspaceI = getTensorInfo<scalar_t , index_type>(workspace);
693720 auto grad_input_gatesI =
@@ -701,6 +728,12 @@ void gru_backward_impl(
701728 {grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx})) {
702729 collapseDims (
703730 grad_hyI, workspaceI, grad_input_gatesI, grad_hidden_gatesI, grad_hxI);
731+ using KernelT =
732+ GruCellBackwardFunctor<scalar_t , accscalar_t , index_type, 1 >;
733+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
734+ auto config = rnn_get_launch_config (max_wg_size, numel);
735+ auto nwg = std::get<0 >(config);
736+ auto local_range = std::get<1 >(config);
704737 KernelT kfn (
705738 grad_input_gatesI,
706739 grad_hidden_gatesI,
@@ -712,6 +745,12 @@ void gru_backward_impl(
712745 sycl_kernel_submit (
713746 nwg * local_range, local_range, getCurrentSYCLQueue (), kfn);
714747 } else {
748+ using KernelT =
749+ GruCellBackwardFunctor<scalar_t , accscalar_t , index_type, 2 >;
750+ auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
751+ auto config = rnn_get_launch_config (max_wg_size, numel);
752+ auto nwg = std::get<0 >(config);
753+ auto local_range = std::get<1 >(config);
715754 KernelT kfn (
716755 grad_input_gatesI,
717756 grad_hidden_gatesI,
0 commit comments