Skip to content

Commit 66c4006

Browse files
Implementation of dpctl.tensor.greater_equal function
1 parent 6e98feb commit 66c4006

File tree

5 files changed

+688
-3
lines changed

5 files changed

+688
-3
lines changed

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@
102102
expm1,
103103
floor_divide,
104104
greater,
105+
greater_equal,
105106
imag,
106107
isfinite,
107108
isinf,
@@ -201,6 +202,7 @@
201202
"exp",
202203
"expm1",
203204
"greater",
205+
"greater_equal",
204206
"imag",
205207
"isinf",
206208
"isnan",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,35 @@
324324
)
325325

326326
# B12: ==== GREATER_EQUAL (x1, x2)
327-
# FIXME: implement B12
327+
_greater_equal_docstring_ = """
greater_equal(x1, x2, out=None, order='K')
Computes the greater-than or equal-to test results for each element `x1_i`
of the input array `x1` with the respective element `x2_i` of the input
array `x2`.
Args:
    x1 (usm_ndarray):
        First input array, expected to have numeric data type.
    x2 (usm_ndarray):
        Second input array, also expected to have numeric data type.
    out ({None, usm_ndarray}, optional):
        Output array to populate.
        Array must have the correct shape and the expected data type.
    order ("C","F","A","K", optional):
        Memory layout of the newly allocated output array, if parameter
        `out` is `None`.
        Default: "K".
Returns:
    usm_ndarray:
        An array containing the result of element-wise greater-than or
        equal-to comparison.
        The returned array has a data type of `bool`.
"""

# Binary element-wise function object: the result-type query and the kernel
# launcher both come from the `ti` extension module.
greater_equal = BinaryElementwiseFunc(
    "greater_equal",
    ti._greater_equal_result_type,
    ti._greater_equal,
    _greater_equal_docstring_,
)
328356

329357
# U16: ==== IMAG (x)
330358
_imag_docstring = """
Lines changed: 329 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,329 @@
1+
//=== greater_equal.hpp - Binary function GREATER_EQUAL ------
2+
//*-C++-*--/===//
3+
//
4+
// Data Parallel Control (dpctl)
5+
//
6+
// Copyright 2020-2023 Intel Corporation
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
//
20+
//===---------------------------------------------------------------------===//
21+
///
22+
/// \file
23+
/// This file defines kernels for elementwise evaluation of comparison of
24+
/// tensor elements.
25+
//===---------------------------------------------------------------------===//
26+
27+
#pragma once
28+
#include <CL/sycl.hpp>
29+
#include <cstddef>
30+
#include <cstdint>
31+
#include <type_traits>
32+
33+
#include "utils/offset_utils.hpp"
34+
#include "utils/type_dispatch.hpp"
35+
#include "utils/type_utils.hpp"
36+
37+
#include "kernels/elementwise_functions/common.hpp"
38+
#include <pybind11/pybind11.h>
39+
40+
namespace dpctl
41+
{
42+
namespace tensor
43+
{
44+
namespace kernels
45+
{
46+
namespace greater_equal
47+
{
48+
49+
namespace py = pybind11;
50+
namespace td_ns = dpctl::tensor::type_dispatch;
51+
namespace tu_ns = dpctl::tensor::type_utils;
52+
53+
template <typename argT1, typename argT2, typename resT>
54+
struct GreaterEqualFunctor
55+
{
56+
static_assert(std::is_same_v<resT, bool>);
57+
58+
using supports_sg_loadstore = std::negation<
59+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
60+
using supports_vec = std::conjunction<
61+
std::is_same<argT1, argT2>,
62+
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
63+
tu_ns::is_complex<argT2>>>>;
64+
65+
resT operator()(const argT1 &in1, const argT2 &in2)
66+
{
67+
if constexpr (std::is_same_v<argT1, std::complex<float>> &&
68+
std::is_same_v<argT2, float>)
69+
{
70+
float real1 = std::real(in1);
71+
return (real1 == in2) ? (std::imag(in1) >= 0.0f) : real1 >= in2;
72+
}
73+
else if constexpr (std::is_same_v<argT1, float> &&
74+
std::is_same_v<argT2, std::complex<float>>)
75+
{
76+
float real2 = std::real(in2);
77+
return (in1 == real2) ? (0.0f >= std::imag(in2)) : in1 >= real2;
78+
}
79+
else if constexpr (tu_ns::is_complex<argT1>::value ||
80+
tu_ns::is_complex<argT2>::value)
81+
{
82+
static_assert(std::is_same_v<argT1, argT2>);
83+
using realT = typename argT1::value_type;
84+
realT real1 = std::real(in1);
85+
realT real2 = std::real(in2);
86+
87+
return (real1 == real2) ? (std::imag(in1) >= std::imag(in2))
88+
: real1 >= real2;
89+
}
90+
else {
91+
return (in1 >= in2);
92+
}
93+
}
94+
95+
template <int vec_sz>
96+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
97+
const sycl::vec<argT2, vec_sz> &in2)
98+
{
99+
100+
auto tmp = (in1 >= in2);
101+
102+
if constexpr (std::is_same_v<resT,
103+
typename decltype(tmp)::element_type>) {
104+
return tmp;
105+
}
106+
else {
107+
using dpctl::tensor::type_utils::vec_cast;
108+
109+
return vec_cast<resT, typename decltype(tmp)::element_type, vec_sz>(
110+
tmp);
111+
}
112+
}
113+
};
114+
115+
template <typename argT1,
116+
typename argT2,
117+
typename resT,
118+
unsigned int vec_sz = 4,
119+
unsigned int n_vecs = 2>
120+
using GreaterEqualContigFunctor = elementwise_common::BinaryContigFunctor<
121+
argT1,
122+
argT2,
123+
resT,
124+
GreaterEqualFunctor<argT1, argT2, resT>,
125+
vec_sz,
126+
n_vecs>;
127+
128+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
129+
using GreaterEqualStridedFunctor = elementwise_common::BinaryStridedFunctor<
130+
argT1,
131+
argT2,
132+
resT,
133+
IndexerT,
134+
GreaterEqualFunctor<argT1, argT2, resT>>;
135+
136+
template <typename T1, typename T2> struct GreaterEqualOutputType
137+
{
138+
using value_type = typename std::disjunction< // disjunction is C++17
139+
// feature, supported by DPC++
140+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
141+
td_ns::
142+
BinaryTypeMapResultEntry<T1, std::uint8_t, T2, std::uint8_t, bool>,
143+
td_ns::BinaryTypeMapResultEntry<T1, std::int8_t, T2, std::int8_t, bool>,
144+
td_ns::BinaryTypeMapResultEntry<T1,
145+
std::uint16_t,
146+
T2,
147+
std::uint16_t,
148+
bool>,
149+
td_ns::
150+
BinaryTypeMapResultEntry<T1, std::int16_t, T2, std::int16_t, bool>,
151+
td_ns::BinaryTypeMapResultEntry<T1,
152+
std::uint32_t,
153+
T2,
154+
std::uint32_t,
155+
bool>,
156+
td_ns::
157+
BinaryTypeMapResultEntry<T1, std::int32_t, T2, std::int32_t, bool>,
158+
td_ns::BinaryTypeMapResultEntry<T1,
159+
std::uint64_t,
160+
T2,
161+
std::uint64_t,
162+
bool>,
163+
td_ns::
164+
BinaryTypeMapResultEntry<T1, std::int64_t, T2, std::int64_t, bool>,
165+
td_ns::BinaryTypeMapResultEntry<T1, sycl::half, T2, sycl::half, bool>,
166+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, bool>,
167+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, bool>,
168+
td_ns::BinaryTypeMapResultEntry<T1,
169+
std::complex<float>,
170+
T2,
171+
std::complex<float>,
172+
bool>,
173+
td_ns::BinaryTypeMapResultEntry<T1,
174+
std::complex<double>,
175+
T2,
176+
std::complex<double>,
177+
bool>,
178+
td_ns::
179+
BinaryTypeMapResultEntry<T1, float, T2, std::complex<float>, bool>,
180+
td_ns::
181+
BinaryTypeMapResultEntry<T1, std::complex<float>, T2, float, bool>,
182+
td_ns::DefaultResultEntry<void>>::result_type;
183+
};
184+
185+
template <typename argT1,
186+
typename argT2,
187+
typename resT,
188+
unsigned int vec_sz,
189+
unsigned int n_vecs>
190+
class greater_equal_contig_kernel;
191+
192+
template <typename argTy1, typename argTy2>
193+
sycl::event
194+
greater_equal_contig_impl(sycl::queue exec_q,
195+
size_t nelems,
196+
const char *arg1_p,
197+
py::ssize_t arg1_offset,
198+
const char *arg2_p,
199+
py::ssize_t arg2_offset,
200+
char *res_p,
201+
py::ssize_t res_offset,
202+
const std::vector<sycl::event> &depends = {})
203+
{
204+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
205+
cgh.depends_on(depends);
206+
207+
size_t lws = 64;
208+
constexpr unsigned int vec_sz = 4;
209+
constexpr unsigned int n_vecs = 2;
210+
const size_t n_groups =
211+
((nelems + lws * n_vecs * vec_sz - 1) / (lws * n_vecs * vec_sz));
212+
const auto gws_range = sycl::range<1>(n_groups * lws);
213+
const auto lws_range = sycl::range<1>(lws);
214+
215+
using resTy =
216+
typename GreaterEqualOutputType<argTy1, argTy2>::value_type;
217+
218+
const argTy1 *arg1_tp =
219+
reinterpret_cast<const argTy1 *>(arg1_p) + arg1_offset;
220+
const argTy2 *arg2_tp =
221+
reinterpret_cast<const argTy2 *>(arg2_p) + arg2_offset;
222+
resTy *res_tp = reinterpret_cast<resTy *>(res_p) + res_offset;
223+
224+
cgh.parallel_for<
225+
greater_equal_contig_kernel<argTy1, argTy2, resTy, vec_sz, n_vecs>>(
226+
sycl::nd_range<1>(gws_range, lws_range),
227+
GreaterEqualContigFunctor<argTy1, argTy2, resTy, vec_sz, n_vecs>(
228+
arg1_tp, arg2_tp, res_tp, nelems));
229+
});
230+
return comp_ev;
231+
}
232+
233+
template <typename fnT, typename T1, typename T2>
234+
struct GreaterEqualContigFactory
235+
{
236+
fnT get()
237+
{
238+
if constexpr (std::is_same_v<
239+
typename GreaterEqualOutputType<T1, T2>::value_type,
240+
void>)
241+
{
242+
fnT fn = nullptr;
243+
return fn;
244+
}
245+
else {
246+
fnT fn = greater_equal_contig_impl<T1, T2>;
247+
return fn;
248+
}
249+
}
250+
};
251+
252+
template <typename fnT, typename T1, typename T2>
253+
struct GreaterEqualTypeMapFactory
254+
{
255+
/*! @brief get typeid for output type of operator()>(x, y), always bool */
256+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
257+
{
258+
using rT = typename GreaterEqualOutputType<T1, T2>::value_type;
259+
return td_ns::GetTypeid<rT>{}.get();
260+
}
261+
};
262+
263+
template <typename T1, typename T2, typename resT, typename IndexerT>
264+
class greater_equal_strided_strided_kernel;
265+
266+
template <typename argTy1, typename argTy2>
267+
sycl::event
268+
greater_equal_strided_impl(sycl::queue exec_q,
269+
size_t nelems,
270+
int nd,
271+
const py::ssize_t *shape_and_strides,
272+
const char *arg1_p,
273+
py::ssize_t arg1_offset,
274+
const char *arg2_p,
275+
py::ssize_t arg2_offset,
276+
char *res_p,
277+
py::ssize_t res_offset,
278+
const std::vector<sycl::event> &depends,
279+
const std::vector<sycl::event> &additional_depends)
280+
{
281+
sycl::event comp_ev = exec_q.submit([&](sycl::handler &cgh) {
282+
cgh.depends_on(depends);
283+
cgh.depends_on(additional_depends);
284+
285+
using resTy =
286+
typename GreaterEqualOutputType<argTy1, argTy2>::value_type;
287+
288+
using IndexerT =
289+
typename dpctl::tensor::offset_utils::ThreeOffsets_StridedIndexer;
290+
291+
IndexerT indexer{nd, arg1_offset, arg2_offset, res_offset,
292+
shape_and_strides};
293+
294+
const argTy1 *arg1_tp = reinterpret_cast<const argTy1 *>(arg1_p);
295+
const argTy2 *arg2_tp = reinterpret_cast<const argTy2 *>(arg2_p);
296+
resTy *res_tp = reinterpret_cast<resTy *>(res_p);
297+
298+
cgh.parallel_for<greater_equal_strided_strided_kernel<argTy1, argTy2,
299+
resTy, IndexerT>>(
300+
{nelems},
301+
GreaterEqualStridedFunctor<argTy1, argTy2, resTy, IndexerT>(
302+
arg1_tp, arg2_tp, res_tp, indexer));
303+
});
304+
return comp_ev;
305+
}
306+
307+
template <typename fnT, typename T1, typename T2>
308+
struct GreaterEqualStridedFactory
309+
{
310+
fnT get()
311+
{
312+
if constexpr (std::is_same_v<
313+
typename GreaterEqualOutputType<T1, T2>::value_type,
314+
void>)
315+
{
316+
fnT fn = nullptr;
317+
return fn;
318+
}
319+
else {
320+
fnT fn = greater_equal_strided_impl<T1, T2>;
321+
return fn;
322+
}
323+
}
324+
};
325+
326+
} // namespace greater_equal
327+
} // namespace kernels
328+
} // namespace tensor
329+
} // namespace dpctl

0 commit comments

Comments
 (0)