Skip to content

Commit 6e98feb

Browse files
Implementation of dpctl.tensor.greater function
1 parent fbbf1b3 commit 6e98feb

File tree

5 files changed

+672
-3
lines changed

5 files changed

+672
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@
101101
exp,
102102
expm1,
103103
floor_divide,
104+
greater,
104105
imag,
105106
isfinite,
106107
isinf,
@@ -199,6 +200,7 @@
199200
"cos",
200201
"exp",
201202
"expm1",
203+
"greater",
202204
"imag",
203205
"isinf",
204206
"isnan",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,31 @@
297297
)
298298

299299
# B11: ==== GREATER (x1, x2)
300-
# FIXME: implement B11
300+
_greater_docstring_ = """
301+
greater(x1, x2, out=None, order='K')
302+
Computes the greater-than test results for each element `x1_i` of
303+
the input array `x1` with the respective element `x2_i` of the input array `x2`.
304+
Args:
305+
x1 (usm_ndarray):
306+
First input array, expected to have numeric data type.
307+
x2 (usm_ndarray):
308+
Second input array, also expected to have numeric data type.
309+
out ({None, usm_ndarray}, optional):
310+
Output array to populate.
311+
Array must have the correct shape and the expected data type.
312+
order ("C","F","A","K", optional):
313+
Memory layout of the newly allocated output array, if parameter `out` is `None`.
314+
Default: "K".
315+
Returns:
316+
usm_ndarray:
317+
An array containing the result of element-wise greater-than comparison.
318+
The data type of the returned array is `bool`, regardless of the
319+
data types of `x1` and `x2`.
320+
"""
321+
322+
# Public element-wise `greater` callable: a BinaryElementwiseFunc built from
# the name "greater", the `ti._greater_result_type` and `ti._greater` entry
# points, and the docstring defined above.
greater = BinaryElementwiseFunc(
323+
"greater", ti._greater_result_type, ti._greater, _greater_docstring_
324+
)
301325

302326
# B12: ==== GREATER_EQUAL (x1, x2)
303327
# FIXME: implement B12
Lines changed: 319 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,319 @@
1+
//=== greater.hpp - Binary function GREATER ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of the greater-than
24+
/// comparison (`x1 > x2`) of tensor elements.
25+
//===---------------------------------------------------------------------===//
26+
27+
#pragma once
28+
#include <CL/sycl.hpp>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <type_traits>
32+
33+
#include "utils/offset_utils.hpp"
34+
#include "utils/type_dispatch.hpp"
35+
#include "utils/type_utils.hpp"
36+
37+
#include "kernels/elementwise_functions/common.hpp"
38+
#include <pybind11/pybind11.h>
39+
40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace kernels
45+
{
46+
namespace greater
47+
{
48+
49+
namespace py = pybind11;
50+
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace tu_ns = dpctl::tensor::type_utils;
52+
53+
// Functor evaluating the element-wise comparison `in1 > in2`.
//
// Template parameters:
//   argT1, argT2 - element types of the first and second operand.
//   resT         - result element type; must be `bool` (see static_assert).
//
// Complex operands are ordered lexicographically: real parts are compared
// first and imaginary parts are only consulted when the real parts are equal.
template <typename argT1, typename argT2, typename resT> struct GreaterFunctor
54+
{
55+
static_assert(std::is_same_v<resT, bool>);
56+
57+
// True unless at least one of the argument types is complex.
// NOTE(review): presumably read by the contiguous kernel functor in
// common.hpp to choose its load/store strategy - confirm there.
using supports_sg_loadstore = std::negation<
58+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
59+
// True only when both argument types are identical and neither is complex,
// i.e. when the sycl::vec overload of operator() below is applicable.
using supports_vec = std::conjunction<
60+
std::is_same<argT1, argT2>,
61+
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
62+
tu_ns::is_complex<argT2>>>>;
63+
64+
// Scalar overload: returns `in1 > in2` for a single pair of elements.
resT operator()(const argT1 &in1, const argT2 &in2)
65+
{
66+
// complex<float> vs float: the real operand is treated as having a zero
// imaginary part.
if constexpr (std::is_same_v<argT1, std::complex<float>> &&
67+
std::is_same_v<argT2, float>)
68+
{
69+
float real1 = std::real(in1);
70+
return (real1 == in2) ? (std::imag(in1) > 0.0f) : real1 > in2;
71+
}
72+
// float vs complex<float>: mirror image of the branch above.
else if constexpr (std::is_same_v<argT1, float> &&
73+
std::is_same_v<argT2, std::complex<float>>)
74+
{
75+
float real2 = std::real(in2);
76+
return (in1 == real2) ? (0.0f > std::imag(in2)) : in1 > real2;
77+
}
78+
// Any other combination involving a complex type must have identical
// argument types (enforced at compile time).
else if constexpr (tu_ns::is_complex<argT1>::value ||
79+
tu_ns::is_complex<argT2>::value)
80+
{
81+
static_assert(std::is_same_v<argT1, argT2>);
82+
using realT = typename argT1::value_type;
83+
realT real1 = std::real(in1);
84+
realT real2 = std::real(in2);
85+
86+
// Imaginary parts break ties between equal real parts.
return (real1 == real2) ? (std::imag(in1) > std::imag(in2))
87+
: real1 > real2;
88+
}
89+
// Non-complex operands: defer to the built-in operator>.
else {
90+
return (in1 > in2);
91+
}
92+
}
93+
94+
// Vector overload: applies operator> to whole sycl::vec operands.
template <int vec_sz>
95+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
96+
const sycl::vec<argT2, vec_sz> &in2)
97+
{
98+
99+
auto tmp = (in1 > in2);
100+
101+
// The element type of the comparison result need not be resT; return it
// directly when it matches, otherwise convert it with vec_cast.
if constexpr (std::is_same_v<resT,
102+
typename decltype(tmp)::element_type>) {
103+
return tmp;
104+
}
105+
else {
106+
using dpctl::tensor::type_utils::vec_cast;
107+
108+
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
109+
tmp);
110+
}
111+
}
112+
};
113+
114+
// Alias binding GreaterFunctor to the generic
// elementwise_common::BinaryContigFunctor. `vec_sz` and `n_vecs` are forwarded
// unchanged and default to 4 and 2 respectively.
template <typename argT1,
115+
typename argT2,
116+
typename resT,
117+
unsigned int vec_sz = 4,
118+
unsigned int n_vecs = 2>
119+
using GreaterContigFunctor =
120+
elementwise_common::BinaryContigFunctor<argT1,
121+
argT2,
122+
resT,
123+
GreaterFunctor<argT1, argT2, resT>,
124+
vec_sz,
125+
n_vecs>;
126+
127+
// Alias binding GreaterFunctor to the generic
// elementwise_common::BinaryStridedFunctor, parameterized by the indexer type
// `IndexerT`.
template <typename argT1, typename argT2, typename resT, typename IndexerT>
128+
using GreaterStridedFunctor = elementwise_common::BinaryStridedFunctor<
129+
argT1,
130+
argT2,
131+
resT,
132+
IndexerT,
133+
GreaterFunctor<argT1, argT2, resT>>;
134+
135+
// Compile-time map from a pair of argument types (T1, T2) to the result type
// of `greater`.
//
// Every listed entry maps to `bool`. Listed pairs are the same-type pairs for
// bool, the fixed-width signed and unsigned integers (8 to 64 bits),
// sycl::half, float, double, std::complex<float> and std::complex<double>,
// plus the two mixed pairs (float, std::complex<float>) and
// (std::complex<float>, float).
//
// The list ends with td_ns::DefaultResultEntry<void>.
// NOTE(review): presumably this makes `value_type` equal to `void` for
// unsupported pairs (the factories in this file test for `void`) - confirm
// against utils/type_dispatch.hpp.
template <typename T1, typename T2> struct GreaterOutputType
136+
{
137+
using value_type = typename std::disjunction< // disjunction is C++17
138+
// feature, supported by DPC++
139+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
140+
td_ns::
141+
BinaryTypeMapResultEntry<T1, std::uint8_t, T2, std::uint8_t, bool>,
142+
td_ns::BinaryTypeMapResultEntry<T1, std::int8_t, T2, std::int8_t, bool>,
143+
td_ns::BinaryTypeMapResultEntry<T1,
144+
std::uint16_t,
145+
T2,
146+
std::uint16_t,
147+
bool>,
148+
td_ns::
149+
BinaryTypeMapResultEntry<T1, std::int16_t, T2, std::int16_t, bool>,
150+
td_ns::BinaryTypeMapResultEntry<T1,
151+
std::uint32_t,
152+
T2,
153+
std::uint32_t,
154+
bool>,
155+
td_ns::
156+
BinaryTypeMapResultEntry<T1, std::int32_t, T2, std::int32_t, bool>,
157+
td_ns::BinaryTypeMapResultEntry<T1,
158+
std::uint64_t,
159+
T2,
160+
std::uint64_t,
161+
bool>,
162+
td_ns::
163+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
164+
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
165+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
166+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,
167+
td_ns::BinaryTypeMapResultEntry<T1,
168+
std::complex<float>,
169+
T2,
170+
std::complex<float>,
171+
bool>,
172+
td_ns::BinaryTypeMapResultEntry<T1,
173+
std::complex<double>,
174+
T2,
175+
std::complex<double>,
176+
bool>,
177+
td_ns::
178+
BinaryTypeMapResultEntry<T1, float, T2, std::complex<float>, bool>,
179+
td_ns::
180+
BinaryTypeMapResultEntry<T1, std::complex<float>, T2, float, bool>,
181+
td_ns::DefaultResultEntry<void>>::result_type;
182+
};
183+
184+
// Declaration-only class template used as the SYCL kernel name for the
// contiguous `greater` kernel submitted by greater_contig_impl.
template <typename argT1,
185+
typename argT2,
186+
typename resT,
187+
unsigned int vec_sz,
188+
unsigned int n_vecs>
189+
class greater_contig_kernel;
190+
191+
// Submits the contiguous `greater` kernel to `exec_q` and returns its event.
//
// Parameters:
//   exec_q      - queue the kernel is submitted to.
//   nelems      - number of elements, forwarded to GreaterContigFunctor.
//   arg1_p,
//   arg2_p      - untyped pointers to the inputs; reinterpreted as
//                 `const argTy1 *` / `const argTy2 *`.
//   res_p       - untyped pointer to the output; reinterpreted as `resTy *`.
//   *_offset    - offsets added AFTER the reinterpret_cast, i.e. measured in
//                 elements of the respective type, not in bytes.
//   depends     - events the kernel must wait for (default: none).
//
// Returns the event associated with the submitted command group.
template <typename argTy1, typename argTy2>
192+
sycl::event greater_contig_impl(sycl::queue exec_q,
193+
size_t nelems,
194+
const char *arg1_p,
195+
py::ssize_t arg1_offset,
196+
const char *arg2_p,
197+
py::ssize_t arg2_offset,
198+
char *res_p,
199+
py::ssize_t res_offset,
200+
const std::vector<sycl::event> &depends = {})
201+
{
202+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
203+
cgh.depends_on(depends);
204+
205+
// Launch geometry: work-groups of `lws` work-items; the number of groups
// is ceil(nelems / (lws * n_vecs * vec_sz)).
size_t lws = 64;
206+
constexpr unsigned int vec_sz = 4;
207+
constexpr unsigned int n_vecs = 2;
208+
const size_t n_groups =
209+
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
210+
const auto gws_range = sycl::range<1>(n_groups * lws);
211+
const auto lws_range = sycl::range<1>(lws);
212+
213+
using resTy = typename GreaterOutputType<argTy1, argTy2>::value_type;
214+
215+
// Typed views of the raw buffers, advanced by the element offsets.
const argTy1 *arg1_tp =
216+
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
217+
const argTy2 *arg2_tp =
218+
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
219+
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
220+
221+
cgh.parallel_for<
222+
greater_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
223+
sycl::nd_range<1>(gws_range, lws_range),
224+
GreaterContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
225+
arg1_tp, arg2_tp, res_tp, nelems));
226+
});
227+
return comp_ev;
228+
}
229+
230+
// Factory producing the contiguous implementation for the type pair (T1, T2).
// get() returns nullptr when GreaterOutputType<T1, T2>::value_type is `void`,
// and a pointer to greater_contig_impl<T1, T2> otherwise.
template <typename fnT, typename T1, typename T2> struct GreaterContigFactory
231+
{
232+
fnT get()
233+
{
234+
// `void` output type: no implementation is provided for this pair.
if constexpr (std::is_same_v<
235+
typename GreaterOutputType<T1, T2>::value_type, void>)
236+
{
237+
fnT fn = nullptr;
238+
return fn;
239+
}
240+
else {
241+
fnT fn = greater_contig_impl<T1, T2>;
242+
return fn;
243+
}
244+
}
245+
};
246+
247+
// Factory reporting the type id of the `greater` result for the type pair
// (T1, T2). get() is only well-formed when `fnT` is `int` (enable_if).
template <typename fnT, typename T1, typename T2> struct GreaterTypeMapFactory
248+
{
249+
/*! @brief get typeid for output type of operator>(x, y), which is bool for
 *  every type pair listed in GreaterOutputType */
250+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
251+
{
252+
using rT = typename GreaterOutputType<T1, T2>::value_type;
253+
return td_ns::GetTypeid<rT>{}.get();
254+
}
255+
};
256+
257+
// Declaration-only class template used as the SYCL kernel name for the
// strided `greater` kernel submitted by greater_strided_impl.
template <typename T1, typename T2, typename resT, typename IndexerT>
258+
class greater_strided_strided_kernel;
259+
260+
// Submits the strided `greater` kernel to `exec_q` and returns its event.
//
// Parameters:
//   exec_q             - queue the kernel is submitted to.
//   nelems             - size of the 1-D range the kernel is launched over.
//   nd                 - number of dimensions, passed to the indexer.
//   shape_and_strides  - passed to ThreeOffsets_StridedIndexer together with
//                        `nd`. NOTE(review): presumably packed shape plus the
//                        strides of both inputs and the output, accessible on
//                        the device - confirm in utils/offset_utils.hpp.
//   arg1_p, arg2_p     - untyped pointers to the inputs.
//   res_p              - untyped pointer to the output.
//   *_offset           - passed to the indexer; unlike the contiguous
//                        implementation, the pointers themselves are NOT
//                        advanced here.
//   depends,
//   additional_depends - events the kernel must wait for; both lists are
//                        registered with the handler.
//
// Returns the event associated with the submitted command group.
template <typename argTy1, typename argTy2>
261+
sycl::event
262+
greater_strided_impl(sycl::queue exec_q,
263+
size_t nelems,
264+
int nd,
265+
const py::ssize_t *shape_and_strides,
266+
const char *arg1_p,
267+
py::ssize_t arg1_offset,
268+
const char *arg2_p,
269+
py::ssize_t arg2_offset,
270+
char *res_p,
271+
py::ssize_t res_offset,
272+
const std::vector<sycl::event> &depends,
273+
const std::vector<sycl::event> &additional_depends)
274+
{
275+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
276+
cgh.depends_on(depends);
277+
cgh.depends_on(additional_depends);
278+
279+
using resTy = typename GreaterOutputType<argTy1, argTy2>::value_type;
280+
281+
using IndexerT =
282+
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
283+
284+
// The indexer receives all three offsets along with the shape/strides.
IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
285+
shape_and_strides};
286+
287+
// Typed views of the raw buffers (no offset applied here).
const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
288+
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
289+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
290+
291+
cgh.parallel_for<
292+
greater_strided_strided_kernel<argTy1, argTy2, resTy, IndexerT>>(
293+
{nelems}, GreaterStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
294+
arg1_tp, arg2_tp, res_tp, indexer));
295+
});
296+
return comp_ev;
297+
}
298+
299+
// Factory producing the strided implementation for the type pair (T1, T2).
// get() returns nullptr when GreaterOutputType<T1, T2>::value_type is `void`,
// and a pointer to greater_strided_impl<T1, T2> otherwise.
template <typename fnT, typename T1, typename T2> struct GreaterStridedFactory
300+
{
301+
fnT get()
302+
{
303+
// `void` output type: no implementation is provided for this pair.
if constexpr (std::is_same_v<
304+
typename GreaterOutputType<T1, T2>::value_type, void>)
305+
{
306+
fnT fn = nullptr;
307+
return fn;
308+
}
309+
else {
310+
fnT fn = greater_strided_impl<T1, T2>;
311+
return fn;
312+
}
313+
}
314+
};
315+
316+
} // namespace greater
317+
} // namespace kernels
318+
} // namespace tensor
319+
} // namespace dpctl

0 commit comments

Comments
 (0)