Skip to content

Commit 7656259

Browse files
Merge pull request #1103 from IntelPython/hotfix/incorrect-mask-positions
Hotfix/incorrect mask positions
2 parents e779c6a + 3eed4e5 commit 7656259

File tree

2 files changed

+61
-29
lines changed

2 files changed

+61
-29
lines changed

dpctl/tensor/libtensor/include/kernels/boolean_advanced_indexing.hpp

Lines changed: 52 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,17 @@ template <typename T1, typename T2> T1 ceiling_quotient(T1 n, T2 m)
5454
return ceiling_quotient<T1>(n, static_cast<T1>(m));
5555
}
5656

57-
template <typename inputT, typename outputT, typename IndexerT, size_t n_wi>
57+
template <typename inputT,
58+
typename outputT,
59+
size_t n_wi,
60+
typename IndexerT,
61+
typename TransformerT>
5862
class inclusive_scan_rec_local_scan_krn;
5963

60-
template <typename inputT, typename outputT, typename IndexerT>
64+
template <typename inputT,
65+
typename outputT,
66+
typename IndexerT,
67+
typename TransformerT>
6168
class inclusive_scan_rec_chunk_update_krn;
6269

6370
struct NoOpIndexer
@@ -136,29 +143,40 @@ struct Strided1DCyclicIndexer
136143
py::ssize_t step = 1;
137144
};
138145

139-
template <typename _IndexerFn> struct ZeroChecker
146+
template <typename inputT, typename outputT> struct NonZeroIndicator
140147
{
148+
NonZeroIndicator() {}
141149

142-
ZeroChecker(_IndexerFn _indexer) : indexer_fn(_indexer) {}
143-
144-
template <typename dataT>
145-
bool operator()(dataT const *data, size_t gid) const
150+
outputT operator()(const inputT &val) const
146151
{
147-
constexpr dataT _zero(0);
152+
constexpr outputT out_one(1);
153+
constexpr outputT out_zero(0);
154+
constexpr inputT val_zero(0);
148155

149-
return data[indexer_fn(gid)] == _zero;
156+
return (val == val_zero) ? out_zero : out_one;
150157
}
158+
};
151159

152-
private:
153-
_IndexerFn indexer_fn;
160+
template <typename T> struct NoOpTransformer
161+
{
162+
NoOpTransformer() {}
163+
164+
T operator()(const T &val) const
165+
{
166+
return val;
167+
}
154168
};
155169

156170
/*
157171
* for integer type maskT,
158172
* output[j] = sum( input[s0 + i * s1], 0 <= i <= j)
159173
* for 0 <= j < n_elems
160174
*/
161-
template <typename inputT, typename outputT, typename IndexerT, size_t n_wi>
175+
template <typename inputT,
176+
typename outputT,
177+
size_t n_wi,
178+
typename IndexerT,
179+
typename TransformerT>
162180
sycl::event inclusive_scan_rec(sycl::queue exec_q,
163181
size_t n_elems,
164182
size_t wg_size,
@@ -167,6 +185,7 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
167185
size_t s0,
168186
size_t s1,
169187
IndexerT indexer,
188+
TransformerT transformer,
170189
std::vector<sycl::event> const &depends = {})
171190
{
172191
size_t n_groups = ceiling_quotient(n_elems, n_wi * wg_size);
@@ -181,9 +200,7 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
181200

182201
slmT slm_iscan_tmp(lws, cgh);
183202

184-
ZeroChecker<IndexerT> is_zero_fn(indexer);
185-
186-
cgh.parallel_for<class inclusive_scan_rec_local_scan_krn<inputT, outputT, ZeroChecker<IndexerT>, n_wi>>(
203+
cgh.parallel_for<class inclusive_scan_rec_local_scan_krn<inputT, outputT, n_wi, IndexerT, decltype(transformer)>>(
187204
sycl::nd_range<1>(gws, lws),
188205
[=](sycl::nd_item<1> it)
189206
{
@@ -195,11 +212,10 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
195212
size_t i = chunk_gid * n_wi;
196213
for (size_t m_wi = 0; m_wi < n_wi; ++m_wi) {
197214
constexpr outputT out_zero(0);
198-
constexpr outputT out_one(1);
215+
199216
local_isum[m_wi] =
200217
(i + m_wi < n_elems)
201-
? (is_zero_fn(input, s0 + s1 * (i + m_wi)) ? out_zero
202-
: out_one)
218+
? transformer(input[indexer(s0 + s1 * (i + m_wi))])
203219
: out_zero;
204220
}
205221

@@ -240,14 +256,17 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
240256
auto chunk_size = wg_size * n_wi;
241257

242258
NoOpIndexer _no_op_indexer{};
243-
auto e2 = inclusive_scan_rec<outputT, outputT, NoOpIndexer, n_wi>(
259+
NoOpTransformer<outputT> _no_op_transformer{};
260+
auto e2 = inclusive_scan_rec<outputT, outputT, n_wi, NoOpIndexer,
261+
decltype(_no_op_transformer)>(
244262
exec_q, n_groups - 1, wg_size, output, temp, chunk_size - 1,
245-
chunk_size, _no_op_indexer, {inc_scan_phase1_ev});
263+
chunk_size, _no_op_indexer, _no_op_transformer,
264+
{inc_scan_phase1_ev});
246265

247266
// output[ chunk_size * (i + 1) + j] += temp[i]
248267
auto e3 = exec_q.submit([&](sycl::handler &cgh) {
249268
cgh.depends_on(e2);
250-
cgh.parallel_for<class inclusive_scan_rec_chunk_update_krn<inputT, outputT, IndexerT>>(
269+
cgh.parallel_for<class inclusive_scan_rec_chunk_update_krn<inputT, outputT, IndexerT, decltype(transformer)>>(
251270
{n_elems},
252271
[=](auto wiid)
253272
{
@@ -258,14 +277,13 @@ sycl::event inclusive_scan_rec(sycl::queue exec_q,
258277
);
259278
});
260279

261-
// dangling task to free the temporary
262-
exec_q.submit([&](sycl::handler &cgh) {
280+
sycl::event e4 = exec_q.submit([&](sycl::handler &cgh) {
263281
cgh.depends_on(e3);
264282
auto ctx = exec_q.get_context();
265283
cgh.host_task([ctx, temp]() { sycl::free(temp, ctx); });
266284
});
267285

268-
out_event = e3;
286+
out_event = e4;
269287
}
270288

271289
return out_event;
@@ -502,10 +520,13 @@ size_t mask_positions_contig_impl(sycl::queue q,
502520
size_t wg_size = 128;
503521

504522
NoOpIndexer flat_indexer{};
523+
NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};
505524

506-
sycl::event comp_ev = inclusive_scan_rec<maskT, cumsumT, NoOpIndexer, n_wi>(
507-
q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0, 1, flat_indexer,
508-
depends);
525+
sycl::event comp_ev =
526+
inclusive_scan_rec<maskT, cumsumT, n_wi, decltype(flat_indexer),
527+
decltype(non_zero_indicator)>(
528+
q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0, 1,
529+
flat_indexer, non_zero_indicator, depends);
509530

510531
cumsumT *last_elem = cumsum_data_ptr + (n_elems - 1);
511532

@@ -558,11 +579,13 @@ size_t mask_positions_strided_impl(sycl::queue q,
558579
size_t wg_size = 128;
559580

560581
StridedIndexer strided_indexer{nd, input_offset, shape_strides};
582+
NonZeroIndicator<maskT, cumsumT> non_zero_indicator{};
561583

562584
sycl::event comp_ev =
563-
inclusive_scan_rec<maskT, cumsumT, StridedIndexer, n_wi>(
585+
inclusive_scan_rec<maskT, cumsumT, n_wi, decltype(strided_indexer),
586+
decltype(non_zero_indicator)>(
564587
q, n_elems, wg_size, mask_data_ptr, cumsum_data_ptr, 0, 1,
565-
strided_indexer, depends);
588+
strided_indexer, non_zero_indicator, depends);
566589

567590
cumsumT *last_elem = cumsum_data_ptr + (n_elems - 1);
568591

dpctl/tests/test_usm_ndarray_indexing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1218,3 +1218,12 @@ def test_assign_scalar():
12181218
x[dpt.nonzero(cond)] = -1
12191219
expected = np.array([-1, -1, -1, -1, -1, 0, 1, 2, 3, 4], dtype=x.dtype)
12201220
assert (dpt.asnumpy(x) == expected).all()
1221+
1222+
1223+
def test_nonzero_large():
1224+
get_queue_or_skip()
1225+
m = dpt.full((60, 80), True)
1226+
assert m[m].size == m.size
1227+
1228+
m = dpt.full((30, 60, 80), True)
1229+
assert m[m].size == m.size

0 commit comments

Comments
 (0)