Skip to content

Commit 369a0c9

Browse files
Split IndexToOffset() into offline and online versions (#2136)
Divide IndexToOffset() into two versions, offline and online, to reduce runtime overhead as much as possible. --------- Co-authored-by: mengfei25 <[email protected]>
1 parent 2d4323d commit 369a0c9

File tree

11 files changed

+484
-396
lines changed

11 files changed

+484
-396
lines changed

src/ATen/native/xpu/sycl/Dropout.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ struct FusedDropoutUnrollFunctor {
165165
if (li < total_elements_) {
166166
// Convert `linearIndex` into an offset of `a`
167167
const IndexType aOffset =
168-
IndexToOffset<const scalar_t, IndexType>::get(li, a_);
168+
IndexToOffset<const scalar_t, IndexType, ADims>::get(li, a_);
169169
src[ii] = a_.data[aOffset];
170170
}
171171
}
@@ -174,7 +174,7 @@ struct FusedDropoutUnrollFunctor {
174174
if (li < total_elements_) {
175175
// Convert `linearIndex` into an offset of `b`
176176
const IndexType bOffset =
177-
IndexToOffset<scalar_t, IndexType>::get(li, b_);
177+
IndexToOffset<scalar_t, IndexType, BDims>::get(li, b_);
178178
b_.data[bOffset] = src[ii] * (&rand.x)[ii] * scale;
179179
c_.data[bOffset] = (mask_t)(&rand.x)[ii];
180180
}

src/ATen/native/xpu/sycl/Indexing.cpp

Lines changed: 218 additions & 50 deletions
Large diffs are not rendered by default.

src/ATen/native/xpu/sycl/Indexing.h

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -211,10 +211,8 @@ class IndexKernel {
211211
if constexpr (TrivialOffCal) {
212212
idx_off = idx_logical_off;
213213
} else {
214-
idx_off = IndexToOffset<IdxType, int64_t>::get(
215-
idx_logical_off,
216-
cfg_.iinfo_,
217-
IndexToOffset<IdxType, int64_t>::NON_STRICT_CONTIGUOUS);
214+
idx_off = IndexToOffset<IdxType, int64_t, -1>::get(
215+
idx_logical_off, cfg_.iinfo_);
218216
}
219217
glb_batch_group = id.glb_batch / cfg_.index_num_;
220218
glb_batch_group_loc_off = cfg_.iinfo_.data[idx_off];
@@ -322,26 +320,18 @@ class IndexKernel {
322320
} else {
323321
if (cfg_.indexing_dst_) {
324322
// index_copy, index_add, index_fill
325-
dst_off = IndexToOffset<ValType, int64_t>::get(
326-
glb_indexing_logical_off,
327-
cfg_.dinfo_,
328-
IndexToOffset<ValType, int64_t>::NON_STRICT_CONTIGUOUS);
323+
dst_off = IndexToOffset<ValType, int64_t, -1>::get(
324+
glb_indexing_logical_off, cfg_.dinfo_);
329325
if (cfg_.sinfo_.data != nullptr) {
330-
src_off = IndexToOffset<const ValType, int64_t>::get(
331-
glb_fixing_logical_off,
332-
cfg_.sinfo_,
333-
IndexToOffset<const ValType, int64_t>::NON_STRICT_CONTIGUOUS);
326+
src_off = IndexToOffset<const ValType, int64_t, -1>::get(
327+
glb_fixing_logical_off, cfg_.sinfo_);
334328
}
335329
} else {
336330
// index_select
337-
src_off = IndexToOffset<const ValType, int64_t>::get(
338-
glb_indexing_logical_off,
339-
cfg_.sinfo_,
340-
IndexToOffset<const ValType, int64_t>::NON_STRICT_CONTIGUOUS);
341-
dst_off = IndexToOffset<ValType, int64_t>::get(
342-
glb_fixing_logical_off,
343-
cfg_.dinfo_,
344-
IndexToOffset<ValType, int64_t>::NON_STRICT_CONTIGUOUS);
331+
src_off = IndexToOffset<const ValType, int64_t, -1>::get(
332+
glb_indexing_logical_off, cfg_.sinfo_);
333+
dst_off = IndexToOffset<ValType, int64_t, -1>::get(
334+
glb_fixing_logical_off, cfg_.dinfo_);
345335
}
346336
}
347337
cfg_.func_(

src/ATen/native/xpu/sycl/RNNKernels.cpp

Lines changed: 70 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,13 @@ void collapseDims(TensorInfo<T, T2>& info, Args&... infos) {
7777
collapseDims(infos...);
7878
}
7979

80-
#define DEVICE_LINEAR_GET(D_TENSOR, INDEX) \
81-
D_TENSOR.data[IndexToOffset<scalar_t, index_type>::get(INDEX, D_TENSOR)]
80+
#define DEVICE_LINEAR_GET(D_TENSOR, INDEX) \
81+
D_TENSOR.data[IndexToOffset<scalar_t, index_type, indexing_kind>::get( \
82+
INDEX, D_TENSOR)]
8283

8384
// Biases are always 1D
8485
#define DEVICE_BIAS_GET(D_TENSOR, INDEX) \
85-
D_TENSOR.data[IndexToOffset<scalar_t, index_type>::get(INDEX, D_TENSOR)]
86+
D_TENSOR.data[IndexToOffset<scalar_t, index_type, 1>::get(INDEX, D_TENSOR)]
8687

8788
#define H2F(input) static_cast<accscalar_t>(input)
8889
#define F2H(input) static_cast<scalar_t>(input)
@@ -93,7 +94,11 @@ inline T sigmoid(T in) {
9394
return one / (one + std::exp(-in));
9495
}
9596

96-
template <typename scalar_t, typename accscalar_t, typename index_type>
97+
template <
98+
typename scalar_t,
99+
typename accscalar_t,
100+
typename index_type,
101+
int indexing_kind>
97102
struct LstmCellForwardFunctor {
98103
void operator()(sycl::nd_item<1> item) const {
99104
bool has_bias = bias1_.data != nullptr;
@@ -205,7 +210,11 @@ struct LstmCellForwardFunctor {
205210
index_type totalElements_;
206211
};
207212

208-
template <typename scalar_t, typename accscalar_t, typename index_type>
213+
template <
214+
typename scalar_t,
215+
typename accscalar_t,
216+
typename index_type,
217+
int indexing_kind>
209218
struct LstmCellBackwardFunctor {
210219
void operator()(sycl::nd_item<1> item) const {
211220
bool has_gradoutput = gradoutput_.data != nullptr;
@@ -296,7 +305,11 @@ struct LstmCellBackwardFunctor {
296305
index_type totalElements_;
297306
};
298307

299-
template <typename scalar_t, typename accscalar_t, typename index_type>
308+
template <
309+
typename scalar_t,
310+
typename accscalar_t,
311+
typename index_type,
312+
int indexing_kind>
300313
struct GruCellForwardFunctor {
301314
void operator()(sycl::nd_item<1> item) const {
302315
bool has_bias = Bias1_.data != nullptr;
@@ -387,7 +400,11 @@ struct GruCellForwardFunctor {
387400
const index_type totalElements_;
388401
};
389402

390-
template <typename scalar_t, typename accscalar_t, typename index_type>
403+
template <
404+
typename scalar_t,
405+
typename accscalar_t,
406+
typename index_type,
407+
int indexing_kind>
391408
struct GruCellBackwardFunctor {
392409
void operator()(sycl::nd_item<1> item) const {
393410
for (index_type linearIndex = item.get_global_id(0);
@@ -469,12 +486,6 @@ void lstm_forward_impl(
469486
if (numel == 0)
470487
return;
471488

472-
using KernelT = LstmCellForwardFunctor<scalar_t, accscalar_t, index_type>;
473-
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
474-
auto config = rnn_get_launch_config(max_wg_size, numel);
475-
auto nwg = std::get<0>(config);
476-
auto local_range = std::get<1>(config);
477-
478489
auto input_gatesI = getTensorInfo<scalar_t, index_type>(input_gates);
479490
auto hidden_gatesI = getTensorInfo<scalar_t, index_type>(hidden_gates);
480491
auto input_biasI = tryGetTensorInfo<scalar_t, index_type>(input_bias);
@@ -503,6 +514,12 @@ void lstm_forward_impl(
503514
hyI,
504515
cyI,
505516
workspaceI);
517+
using KernelT =
518+
LstmCellForwardFunctor<scalar_t, accscalar_t, index_type, 1>;
519+
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
520+
auto config = rnn_get_launch_config(max_wg_size, numel);
521+
auto nwg = std::get<0>(config);
522+
auto local_range = std::get<1>(config);
506523
KernelT kfn(
507524
input_gatesI,
508525
hidden_gatesI,
@@ -517,6 +534,12 @@ void lstm_forward_impl(
517534
sycl_kernel_submit(
518535
nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
519536
} else {
537+
using KernelT =
538+
LstmCellForwardFunctor<scalar_t, accscalar_t, index_type, 2>;
539+
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
540+
auto config = rnn_get_launch_config(max_wg_size, numel);
541+
auto nwg = std::get<0>(config);
542+
auto local_range = std::get<1>(config);
520543
KernelT kfn(
521544
input_gatesI,
522545
hidden_gatesI,
@@ -548,12 +571,6 @@ void lstm_backward_impl(
548571
if (numel == 0)
549572
return;
550573

551-
using KernelT = LstmCellBackwardFunctor<scalar_t, accscalar_t, index_type>;
552-
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
553-
auto config = rnn_get_launch_config(max_wg_size, numel);
554-
auto nwg = std::get<0>(config);
555-
auto local_range = std::get<1>(config);
556-
557574
auto grad_hyI = tryGetTensorInfo<scalar_t, index_type>(grad_hy);
558575
auto grad_cyI = tryGetTensorInfo<scalar_t, index_type>(grad_cy);
559576
auto cxI = getTensorInfo<scalar_t, index_type>(cx);
@@ -567,6 +584,12 @@ void lstm_backward_impl(
567584
{grad_hy, grad_cy, cx, cy, workspace, grad_gates, grad_cx})) {
568585
collapseDims(
569586
grad_hyI, grad_cyI, cxI, cyI, workspaceI, grad_gatesI, grad_cxI);
587+
using KernelT =
588+
LstmCellBackwardFunctor<scalar_t, accscalar_t, index_type, 1>;
589+
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
590+
auto config = rnn_get_launch_config(max_wg_size, numel);
591+
auto nwg = std::get<0>(config);
592+
auto local_range = std::get<1>(config);
570593
KernelT kfn(
571594
workspaceI,
572595
grad_gatesI,
@@ -580,6 +603,12 @@ void lstm_backward_impl(
580603
sycl_kernel_submit(
581604
nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
582605
} else {
606+
using KernelT =
607+
LstmCellBackwardFunctor<scalar_t, accscalar_t, index_type, 2>;
608+
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
609+
auto config = rnn_get_launch_config(max_wg_size, numel);
610+
auto nwg = std::get<0>(config);
611+
auto local_range = std::get<1>(config);
583612
KernelT kfn(
584613
workspaceI,
585614
grad_gatesI,
@@ -610,12 +639,6 @@ void gru_forward_impl(
610639
if (numel == 0)
611640
return;
612641

613-
using KernelT = GruCellForwardFunctor<scalar_t, accscalar_t, index_type>;
614-
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
615-
auto config = rnn_get_launch_config(max_wg_size, numel);
616-
auto nwg = std::get<0>(config);
617-
auto local_range = std::get<1>(config);
618-
619642
auto input_gatesI = getTensorInfo<scalar_t, index_type>(input_gates);
620643
auto hidden_gatesI = getTensorInfo<scalar_t, index_type>(hidden_gates);
621644
auto input_biasI = tryGetTensorInfo<scalar_t, index_type>(input_bias);
@@ -641,6 +664,11 @@ void gru_forward_impl(
641664
hxI,
642665
hyI,
643666
workspaceI);
667+
using KernelT = GruCellForwardFunctor<scalar_t, accscalar_t, index_type, 1>;
668+
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
669+
auto config = rnn_get_launch_config(max_wg_size, numel);
670+
auto nwg = std::get<0>(config);
671+
auto local_range = std::get<1>(config);
644672
KernelT kfn(
645673
input_gatesI,
646674
hidden_gatesI,
@@ -654,6 +682,11 @@ void gru_forward_impl(
654682
sycl_kernel_submit(
655683
nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
656684
} else {
685+
using KernelT = GruCellForwardFunctor<scalar_t, accscalar_t, index_type, 2>;
686+
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
687+
auto config = rnn_get_launch_config(max_wg_size, numel);
688+
auto nwg = std::get<0>(config);
689+
auto local_range = std::get<1>(config);
657690
KernelT kfn(
658691
input_gatesI,
659692
hidden_gatesI,
@@ -682,12 +715,6 @@ void gru_backward_impl(
682715
if (numel == 0)
683716
return;
684717

685-
using KernelT = GruCellBackwardFunctor<scalar_t, accscalar_t, index_type>;
686-
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
687-
auto config = rnn_get_launch_config(max_wg_size, numel);
688-
auto nwg = std::get<0>(config);
689-
auto local_range = std::get<1>(config);
690-
691718
auto grad_hyI = getTensorInfo<scalar_t, index_type>(grad_hy);
692719
auto workspaceI = getTensorInfo<scalar_t, index_type>(workspace);
693720
auto grad_input_gatesI =
@@ -701,6 +728,12 @@ void gru_backward_impl(
701728
{grad_hy, workspace, grad_input_gates, grad_hidden_gates, grad_hx})) {
702729
collapseDims(
703730
grad_hyI, workspaceI, grad_input_gatesI, grad_hidden_gatesI, grad_hxI);
731+
using KernelT =
732+
GruCellBackwardFunctor<scalar_t, accscalar_t, index_type, 1>;
733+
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
734+
auto config = rnn_get_launch_config(max_wg_size, numel);
735+
auto nwg = std::get<0>(config);
736+
auto local_range = std::get<1>(config);
704737
KernelT kfn(
705738
grad_input_gatesI,
706739
grad_hidden_gatesI,
@@ -712,6 +745,12 @@ void gru_backward_impl(
712745
sycl_kernel_submit(
713746
nwg * local_range, local_range, getCurrentSYCLQueue(), kfn);
714747
} else {
748+
using KernelT =
749+
GruCellBackwardFunctor<scalar_t, accscalar_t, index_type, 2>;
750+
auto max_wg_size = syclMaxWorkGroupSize<KernelT>();
751+
auto config = rnn_get_launch_config(max_wg_size, numel);
752+
auto nwg = std::get<0>(config);
753+
auto local_range = std::get<1>(config);
715754
KernelT kfn(
716755
grad_input_gatesI,
717756
grad_hidden_gatesI,

0 commit comments

Comments
 (0)