Skip to content

Commit 0185fa9

Browse files
Impementation of dpctl.tensor.less function (#1235)
* Impementation of dpctl.tensor.less function * Add tests for dpctl.tensor.less * Replace branching with ternary operator * Fix remarks and extend support for complex and float * Update tests for less function
1 parent f35b34d commit 0185fa9

File tree

5 files changed

+670
-3
lines changed

5 files changed

+670
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@
105105
isfinite,
106106
isinf,
107107
isnan,
108+
less,
108109
log,
109110
log1p,
110111
multiply,
@@ -202,6 +203,7 @@
202203
"isinf",
203204
"isnan",
204205
"isfinite",
206+
"less",
205207
"log",
206208
"log1p",
207209
"proj",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,33 @@
405405
)
406406

407407
# B13: ==== LESS (x1, x2)
408-
# FIXME: implement B13
408+
_less_docstring_ = """
409+
less(x1, x2, out=None, order='K')
410+
411+
Computes the less-than test results for each element `x1_i` of
412+
the input array `x1` the respective element `x2_i` of the input array `x2`.
413+
414+
Args:
415+
x1 (usm_ndarray):
416+
First input array, expected to have numeric data type.
417+
x2 (usm_ndarray):
418+
Second input array, also expected to have numeric data type.
419+
out ({None, usm_ndarray}, optional):
420+
Output array to populate.
421+
Array have the correct shape and the expected data type.
422+
order ("C","F","A","K", optional):
423+
Memory layout of the newly output array, if parameter `out` is `None`.
424+
Default: "K".
425+
Returns:
426+
usm_narray:
427+
An array containing the result of element-wise less-than comparison.
428+
The data type of the returned array is determined by the
429+
Type Promotion Rules.
430+
"""
431+
432+
less = BinaryElementwiseFunc(
433+
"less", ti._less_result_type, ti._less, _less_docstring_
434+
)
409435

410436
# B14: ==== LESS_EQUAL (x1, x2)
411437
# FIXME: implement B14
Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
//=== less.hpp - Binary function LESS ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain in1 copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of comparison of
23+
/// tensor elements.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "utils/offset_utils.hpp"
33+
#include "utils/type_dispatch.hpp"
34+
#include "utils/type_utils.hpp"
35+
36+
#include "kernels/elementwise_functions/common.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace less
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace tu_ns = dpctl::tensor::type_utils;
51+
52+
template <typename argT1, typename argT2, typename resT> struct LessFunctor
53+
{
54+
static_assert(std::is_same_v<resT, bool>);
55+
56+
using supports_sg_loadstore = std::negation<
57+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
58+
using supports_vec = std::conjunction<
59+
std::is_same<argT1, argT2>,
60+
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
61+
tu_ns::is_complex<argT2>>>>;
62+
63+
resT operator()(const argT1 &in1, const argT2 &in2)
64+
{
65+
if constexpr (std::is_same_v<argT1, std::complex<float>> &&
66+
std::is_same_v<argT2, float>)
67+
{
68+
float real1 = std::real(in1);
69+
return (real1 == in2) ? (std::imag(in1) < 0.0f) : real1 < in2;
70+
}
71+
else if constexpr (std::is_same_v<argT1, float> &&
72+
std::is_same_v<argT2, std::complex<float>>)
73+
{
74+
float real2 = std::real(in2);
75+
return (in1 == real2) ? (0.0f < std::imag(in2)) : in1 < real2;
76+
}
77+
else if constexpr (tu_ns::is_complex<argT1>::value ||
78+
tu_ns::is_complex<argT2>::value)
79+
{
80+
static_assert(std::is_same_v<argT1, argT2>);
81+
using realT = typename argT1::value_type;
82+
realT real1 = std::real(in1);
83+
realT real2 = std::real(in2);
84+
85+
return (real1 == real2) ? (std::imag(in1) < std::imag(in2))
86+
: real1 < real2;
87+
}
88+
else {
89+
return (in1 < in2);
90+
}
91+
}
92+
93+
template <int vec_sz>
94+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
95+
const sycl::vec<argT2, vec_sz> &in2)
96+
{
97+
98+
auto tmp = (in1 < in2);
99+
100+
if constexpr (std::is_same_v<resT,
101+
typename decltype(tmp)::element_type>) {
102+
return tmp;
103+
}
104+
else {
105+
using dpctl::tensor::type_utils::vec_cast;
106+
107+
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
108+
tmp);
109+
}
110+
}
111+
};
112+
113+
template <typename argT1,
114+
typename argT2,
115+
typename resT,
116+
unsigned int vec_sz = 4,
117+
unsigned int n_vecs = 2>
118+
using LessContigFunctor =
119+
elementwise_common::BinaryContigFunctor<argT1,
120+
argT2,
121+
resT,
122+
LessFunctor<argT1, argT2, resT>,
123+
vec_sz,
124+
n_vecs>;
125+
126+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
127+
using LessStridedFunctor =
128+
elementwise_common::BinaryStridedFunctor<argT1,
129+
argT2,
130+
resT,
131+
IndexerT,
132+
LessFunctor<argT1, argT2, resT>>;
133+
134+
template <typename T1, typename T2> struct LessOutputType
135+
{
136+
using value_type = typename std::disjunction< // disjunction is C++17
137+
// feature, supported by DPC++
138+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
139+
td_ns::
140+
BinaryTypeMapResultEntry<T1, std::uint8_t, T2, std::uint8_t, bool>,
141+
td_ns::BinaryTypeMapResultEntry<T1, std::int8_t, T2, std::int8_t, bool>,
142+
td_ns::BinaryTypeMapResultEntry<T1,
143+
std::uint16_t,
144+
T2,
145+
std::uint16_t,
146+
bool>,
147+
td_ns::
148+
BinaryTypeMapResultEntry<T1, std::int16_t, T2, std::int16_t, bool>,
149+
td_ns::BinaryTypeMapResultEntry<T1,
150+
std::uint32_t,
151+
T2,
152+
std::uint32_t,
153+
bool>,
154+
td_ns::
155+
BinaryTypeMapResultEntry<T1, std::int32_t, T2, std::int32_t, bool>,
156+
td_ns::BinaryTypeMapResultEntry<T1,
157+
std::uint64_t,
158+
T2,
159+
std::uint64_t,
160+
bool>,
161+
td_ns::
162+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
163+
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
164+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
165+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,
166+
td_ns::BinaryTypeMapResultEntry<T1,
167+
std::complex<float>,
168+
T2,
169+
std::complex<float>,
170+
bool>,
171+
td_ns::BinaryTypeMapResultEntry<T1,
172+
std::complex<double>,
173+
T2,
174+
std::complex<double>,
175+
bool>,
176+
td_ns::
177+
BinaryTypeMapResultEntry<T1, float, T2, std::complex<float>, bool>,
178+
td_ns::
179+
BinaryTypeMapResultEntry<T1, std::complex<float>, T2, float, bool>,
180+
td_ns::DefaultResultEntry<void>>::result_type;
181+
};
182+
183+
template <typename argT1,
184+
typename argT2,
185+
typename resT,
186+
unsigned int vec_sz,
187+
unsigned int n_vecs>
188+
class less_contig_kernel;
189+
190+
template <typename argTy1, typename argTy2>
191+
sycl::event less_contig_impl(sycl::queue exec_q,
192+
size_t nelems,
193+
const char *arg1_p,
194+
py::ssize_t arg1_offset,
195+
const char *arg2_p,
196+
py::ssize_t arg2_offset,
197+
char *res_p,
198+
py::ssize_t res_offset,
199+
const std::vector<sycl::event> &depends = {})
200+
{
201+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
202+
cgh.depends_on(depends);
203+
204+
size_t lws = 64;
205+
constexpr unsigned int vec_sz = 4;
206+
constexpr unsigned int n_vecs = 2;
207+
const size_t n_groups =
208+
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
209+
const auto gws_range = sycl::range<1>(n_groups * lws);
210+
const auto lws_range = sycl::range<1>(lws);
211+
212+
using resTy = typename LessOutputType<argTy1, argTy2>::value_type;
213+
214+
const argTy1 *arg1_tp =
215+
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
216+
const argTy2 *arg2_tp =
217+
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
218+
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
219+
220+
cgh.parallel_for<
221+
less_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
222+
sycl::nd_range<1>(gws_range, lws_range),
223+
LessContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
224+
arg1_tp, arg2_tp, res_tp, nelems));
225+
});
226+
return comp_ev;
227+
}
228+
229+
template <typename fnT, typename T1, typename T2> struct LessContigFactory
230+
{
231+
fnT get()
232+
{
233+
if constexpr (std::is_same_v<
234+
typename LessOutputType<T1, T2>::value_type, void>) {
235+
fnT fn = nullptr;
236+
return fn;
237+
}
238+
else {
239+
fnT fn = less_contig_impl<T1, T2>;
240+
return fn;
241+
}
242+
}
243+
};
244+
245+
template <typename fnT, typename T1, typename T2> struct LessTypeMapFactory
246+
{
247+
/*! @brief get typeid for output type of operator()>(x, y), always bool */
248+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
249+
{
250+
using rT = typename LessOutputType<T1, T2>::value_type;
251+
return td_ns::GetTypeid<rT>{}.get();
252+
}
253+
};
254+
255+
template <typename T1, typename T2, typename resT, typename IndexerT>
256+
class less_strided_strided_kernel;
257+
258+
template <typename argTy1, typename argTy2>
259+
sycl::event
260+
less_strided_impl(sycl::queue exec_q,
261+
size_t nelems,
262+
int nd,
263+
const py::ssize_t *shape_and_strides,
264+
const char *arg1_p,
265+
py::ssize_t arg1_offset,
266+
const char *arg2_p,
267+
py::ssize_t arg2_offset,
268+
char *res_p,
269+
py::ssize_t res_offset,
270+
const std::vector<sycl::event> &depends,
271+
const std::vector<sycl::event> &additional_depends)
272+
{
273+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
274+
cgh.depends_on(depends);
275+
cgh.depends_on(additional_depends);
276+
277+
using resTy = typename LessOutputType<argTy1, argTy2>::value_type;
278+
279+
using IndexerT =
280+
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
281+
282+
IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
283+
shape_and_strides};
284+
285+
const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
286+
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
287+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
288+
289+
cgh.parallel_for<
290+
less_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
291+
{nelems}, LessStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
292+
arg1_tp, arg2_tp, res_tp, indexer));
293+
});
294+
return comp_ev;
295+
}
296+
297+
template <typename fnT, typename T1, typename T2> struct LessStridedFactory
298+
{
299+
fnT get()
300+
{
301+
if constexpr (std::is_same_v<
302+
typename LessOutputType<T1, T2>::value_type, void>) {
303+
fnT fn = nullptr;
304+
return fn;
305+
}
306+
else {
307+
fnT fn = less_strided_impl<T1, T2>;
308+
return fn;
309+
}
310+
}
311+
};
312+
313+
} // namespace less
314+
} // namespace kernels
315+
} // namespace tensor
316+
} // namespace dpctl

0 commit comments

Comments
 (0)