@@ -54,10 +54,17 @@ template <typename T1, typename T2> T1 ceiling_quotient(T1 n, T2 m)
     return ceiling_quotient<T1>(n, static_cast<T1>(m));
 }
 
-template <typename inputT, typename outputT, typename IndexerT, size_t n_wi>
+template <typename inputT,
+          typename outputT,
+          size_t n_wi,
+          typename IndexerT,
+          typename TransformerT>
 class inclusive_scan_rec_local_scan_krn;
 
-template <typename inputT, typename outputT, typename IndexerT>
+template <typename inputT,
+          typename outputT,
+          typename IndexerT,
+          typename TransformerT>
 class inclusive_scan_rec_chunk_update_krn;
 
 struct NoOpIndexer
@@ -136,29 +143,40 @@ struct Strided1DCyclicIndexer
     py::ssize_t step = 1;
 };
 
-template <typename _IndexerFn> struct ZeroChecker
+template <typename inputT, typename outputT> struct NonZeroIndicator
 {
+    NonZeroIndicator() {}
 
-    ZeroChecker(_IndexerFn _indexer) : indexer_fn(_indexer) {}
-
-    template <typename dataT>
-    bool operator()(dataT const *data, size_t gid) const
+    outputT operator()(const inputT &val) const
     {
-        constexpr dataT _zero(0);
+        constexpr outputT out_one(1);
+        constexpr outputT out_zero(0);
+        constexpr inputT val_zero(0);
 
-        return data[indexer_fn(gid)] == _zero;
+        return (val == val_zero) ? out_zero : out_one;
     }
+};
 
-private:
-    _IndexerFn indexer_fn;
+template <typename T> struct NoOpTransformer
+{
+    NoOpTransformer() {}
+
+    T operator()(const T &val) const
+    {
+        return val;
+    }
 };
 
 /*
  * for integer type maskT,
  * output[j] = sum( input[s0 + i * s1], 0 <= i <= j)
  * for 0 <= j < n_elems
  */
-template <typename inputT, typename outputT, typename IndexerT, size_t n_wi>
+template <typename inputT,
+          typename outputT,
+          size_t n_wi,
+          typename IndexerT,
+          typename TransformerT>
 sycl::event inclusive_scan_rec(sycl::queue exec_q,
                                size_t n_elems,
                                size_t wg_size,
@@ -167,6 +185,7 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
                                size_t s0,
                                size_t s1,
                                IndexerT indexer,
+                               TransformerT transformer,
                                std::vector<sycl::event> const &depends = {})
 {
     size_t n_groups = ceiling_quotient(n_elems, n_wi * wg_size);
@@ -181,9 +200,7 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
 
         slmT slm_iscan_tmp(lws, cgh);
 
-        ZeroChecker<IndexerT> is_zero_fn(indexer);
-
-        cgh.parallel_for<class inclusive_scan_rec_local_scan_krn<inputT, outputT, ZeroChecker<IndexerT>, n_wi>>(
+        cgh.parallel_for<class inclusive_scan_rec_local_scan_krn<inputT, outputT, n_wi, IndexerT, decltype(transformer)>>(
             sycl::nd_range<1>(gws, lws),
             [=](sycl::nd_item<1> it)
             {
@@ -195,11 +212,10 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
                 size_t i = chunk_gid * n_wi;
                 for (size_t m_wi = 0; m_wi < n_wi; ++m_wi) {
                     constexpr outputT out_zero(0);
-                    constexpr outputT out_one(1);
+
                     local_isum[m_wi] =
                         (i + m_wi < n_elems)
-                            ? (is_zero_fn(input, s0 + s1 * (i + m_wi)) ? out_zero
-                                                                       : out_one)
+                            ? transformer(input[indexer(s0 + s1 * (i + m_wi))])
                             : out_zero;
                 }
 
@@ -240,14 +256,17 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
         auto chunk_size = wg_size * n_wi;
 
         NoOpIndexer _no_op_indexer{};
-        auto e2 = inclusive_scan_rec<outputT, outputT, NoOpIndexer, n_wi>(
+        NoOpTransformer<outputT> _no_op_transformer{};
+        auto e2 = inclusive_scan_rec<outputT, outputT, n_wi, NoOpIndexer,
+                                     decltype(_no_op_transformer)>(
             exec_q, n_groups - 1, wg_size, output, temp, chunk_size - 1,
-            chunk_size, _no_op_indexer, {inc_scan_phase1_ev});
+            chunk_size, _no_op_indexer, _no_op_transformer,
+            {inc_scan_phase1_ev});
 
         // output[ chunk_size * (i + 1) + j] += temp[i]
         auto e3 = exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(e2);
-            cgh.parallel_for<class inclusive_scan_rec_chunk_update_krn<inputT, outputT, IndexerT>>(
+            cgh.parallel_for<class inclusive_scan_rec_chunk_update_krn<inputT, outputT, IndexerT, decltype(transformer)>>(
                 {n_elems},
                 [=](auto wiid)
                 {
@@ -258,14 +277,13 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
             );
         });
 
-        // dangling task to free the temporary
-        exec_q.submit([&](sycl::handler &cgh) {
+        sycl::event e4 = exec_q.submit([&](sycl::handler &cgh) {
             cgh.depends_on(e3);
             auto ctx = exec_q.get_context();
             cgh.host_task([ctx, temp]() { sycl::free(temp, ctx); });
         });
 
-        out_event = e3;
+        out_event = e4;
     }
 
     return out_event;
@@ -502,10 +520,13 @@ size_t mask_positions_contig_impl(sycl::queue q,
     size_t wg_size = 128;
 
     NoOpIndexer flat_indexer{};
+    NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};
 
-    sycl::event comp_ev = inclusive_scan_rec<maskT, cumsumT, NoOpIndexer, n_wi>(
-        q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0, 1, flat_indexer,
-        depends);
+    sycl::event comp_ev =
+        inclusive_scan_rec<maskT, cumsumT, n_wi, decltype(flat_indexer),
+                           decltype(non_zero_indicator)>(
+            q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0, 1,
+            flat_indexer, non_zero_indicator, depends);
 
     cumsumT *last_elem = cumsum_data_ptr + (n_elems - 1);
 
@@ -558,11 +579,13 @@ size_t mask_positions_strided_impl(sycl::queue q,
     size_t wg_size = 128;
 
     StridedIndexer strided_indexer{nd, input_offset, shape_strides};
+    NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};
 
     sycl::event comp_ev =
-        inclusive_scan_rec<maskT, cumsumT, StridedIndexer, n_wi>(
+        inclusive_scan_rec<maskT, cumsumT, n_wi, decltype(strided_indexer),
+                           decltype(non_zero_indicator)>(
             q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0, 1,
-            strided_indexer, depends);
+            strided_indexer, non_zero_indicator, depends);
 
     cumsumT *last_elem = cumsum_data_ptr + (n_elems - 1);
 
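
For reference, here is a minimal host-side sketch (not part of the patch) of the functor contract this change introduces: inclusive_scan_rec now takes an IndexerT that maps a logical position to a storage offset and a TransformerT that maps the fetched value to the scanned type, so NonZeroIndicator<maskT, cumsumT> converts mask entries to 0/1 before accumulation, while NoOpTransformer leaves values unchanged for the recursive chunk-update pass. The serial loop and the identity lambda below are illustrative stand-ins for the SYCL kernels and NoOpIndexer; names and exact semantics beyond the diff are assumptions.

#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors the NonZeroIndicator added in this patch: maps a mask value to 0 or 1.
template <typename inputT, typename outputT> struct NonZeroIndicator
{
    outputT operator()(const inputT &val) const
    {
        constexpr outputT out_one(1);
        constexpr outputT out_zero(0);
        constexpr inputT val_zero(0);

        return (val == val_zero) ? out_zero : out_one;
    }
};

// Serial stand-in for inclusive_scan_rec: cumsum[j] accumulates
// transformer(mask[indexer(i)]) for 0 <= i <= j.
template <typename maskT, typename cumsumT, typename IndexerT, typename TransformerT>
void serial_inclusive_scan(const std::vector<maskT> &mask,
                           std::vector<cumsumT> &cumsum,
                           IndexerT indexer,
                           TransformerT transformer)
{
    cumsumT acc(0);
    for (std::size_t j = 0; j < mask.size(); ++j) {
        acc += transformer(mask[indexer(j)]);
        cumsum[j] = acc;
    }
}

int main()
{
    std::vector<std::uint8_t> mask{0, 1, 1, 0, 1};
    std::vector<std::int64_t> cumsum(mask.size(), 0);

    auto flat_indexer = [](std::size_t i) { return i; }; // stands in for NoOpIndexer
    serial_inclusive_scan(mask, cumsum, flat_indexer,
                          NonZeroIndicator<std::uint8_t, std::int64_t>{});

    for (auto v : cumsum)
        std::cout << v << ' '; // prints: 0 1 2 2 3
    std::cout << '\n';
    return 0;
}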