Skip to content

Commit 9f0fef3

Browse files
[NFCI][SYCL] Complete refactoring from vec_arith to individual mixins
Refer to #16946 for the context.
1 parent 3dae81e commit 9f0fef3

File tree

4 files changed

+158
-170
lines changed

4 files changed

+158
-170
lines changed

sycl/include/sycl/detail/vector_arith.hpp

+123 -135
Original file line number | Diff line number | Diff line change
@@ -59,6 +59,8 @@ struct UnaryPlus {
5959
}
6060
};
6161

62+
template <typename Op> struct OpAssign {};
63+
6264
// Tag to map/templatize the mixin for prefix/postfix inc/dec operators.
6365
struct IncDec {};
6466

@@ -69,6 +71,10 @@ template <class T> static constexpr bool not_fp = !is_vgenfloat_v<T>;
6971
template <typename Op, typename T>
7072
inline constexpr bool is_op_available_for_type = false;
7173

74+
template <typename Op, typename T>
75+
inline constexpr bool is_op_available_for_type<OpAssign<Op>, T> =
76+
is_op_available_for_type<Op, T>;
77+
7278
#define __SYCL_OP_AVAILABILITY(OP, COND) \
7379
template <typename T> \
7480
inline constexpr bool is_op_available_for_type<OP, T> = COND;
@@ -133,23 +139,63 @@ template <typename SelfOperandTy> struct IncDecImpl {
133139
}
134140
};
135141

142+
// clang-format off
143+
#define __SYCL_INSTANTIATE_OPERATORS(BINOP, OPASSIGN, UOP) \
144+
BINOP(std::plus<void> , +) \
145+
BINOP(std::minus<void> , -) \
146+
BINOP(std::multiplies<void> , *) \
147+
BINOP(std::divides<void> , /) \
148+
BINOP(std::modulus<void> , %) \
149+
BINOP(std::bit_and<void> , &) \
150+
BINOP(std::bit_or<void> , |) \
151+
BINOP(std::bit_xor<void> , ^) \
152+
BINOP(std::equal_to<void> , ==) \
153+
BINOP(std::not_equal_to<void> , !=) \
154+
BINOP(std::less<void> , < ) \
155+
BINOP(std::greater<void> , >) \
156+
BINOP(std::less_equal<void> , <=) \
157+
BINOP(std::greater_equal<void> , >=) \
158+
BINOP(std::logical_and<void> , &&) \
159+
BINOP(std::logical_or<void> , ||) \
160+
BINOP(ShiftLeft , <<) \
161+
BINOP(ShiftRight , >>) \
162+
UOP(std::negate<void> , -) \
163+
UOP(std::logical_not<void> , !) \
164+
/* UOP(std::bit_not<void> , ~) */ \
165+
UOP(UnaryPlus , +) \
166+
OPASSIGN(std::plus<void> , +=) \
167+
OPASSIGN(std::minus<void> , -=) \
168+
OPASSIGN(std::multiplies<void> , *=) \
169+
OPASSIGN(std::divides<void> , /=) \
170+
OPASSIGN(std::modulus<void> , %=) \
171+
OPASSIGN(std::bit_and<void> , &=) \
172+
OPASSIGN(std::bit_or<void> , |=) \
173+
OPASSIGN(std::bit_xor<void> , ^=) \
174+
OPASSIGN(ShiftLeft , <<=) \
175+
OPASSIGN(ShiftRight , >>=)
176+
// clang-format on
177+
178+
template <typename Op>
179+
constexpr bool is_logical =
180+
check_type_in_v<Op, std::equal_to<void>, std::not_equal_to<void>,
181+
std::less<void>, std::greater<void>, std::less_equal<void>,
182+
std::greater_equal<void>, std::logical_and<void>,
183+
std::logical_or<void>, std::logical_not<void>>;
184+
136185
template <typename Self> struct VecOperators {
137186
static_assert(is_vec_v<Self>);
138187

188+
using element_type = typename from_incomplete<Self>::element_type;
189+
static constexpr int N = from_incomplete<Self>::size();
190+
191+
template <typename Op>
192+
using result_t = std::conditional_t<
193+
is_logical<Op>, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;
194+
139195
template <typename OpTy, typename... ArgTys>
140196
static constexpr auto apply(const ArgTys &...Args) {
141197
static_assert(((std::is_same_v<Self, ArgTys> && ...)));
142198

143-
using element_type = typename Self::element_type;
144-
constexpr int N = Self::size();
145-
constexpr bool is_logical = check_type_in_v<
146-
OpTy, std::equal_to<void>, std::not_equal_to<void>, std::less<void>,
147-
std::greater<void>, std::less_equal<void>, std::greater_equal<void>,
148-
std::logical_and<void>, std::logical_or<void>, std::logical_not<void>>;
149-
150-
using result_t = std::conditional_t<
151-
is_logical, vec<fixed_width_signed<sizeof(element_type)>, N>, Self>;
152-
153199
OpTy Op{};
154200
#ifdef __has_extension
155201
#if __has_extension(attribute_ext_vector_type)
@@ -198,7 +244,7 @@ template <typename Self> struct VecOperators {
198244
if constexpr (std::is_same_v<element_type, bool>) {
199245
// Some operations are known to produce the required bit patterns and
200246
// the following post-processing isn't necessary for them:
201-
if constexpr (!is_logical &&
247+
if constexpr (!is_logical<OpTy> &&
202248
!check_type_in_v<OpTy, std::multiplies<void>,
203249
std::divides<void>, std::bit_or<void>,
204250
std::bit_and<void>, std::bit_xor<void>,
@@ -225,13 +271,13 @@ template <typename Self> struct VecOperators {
225271
tmp = reinterpret_cast<decltype(tmp)>((tmp != 0) * -1);
226272
}
227273
}
228-
return bit_cast<result_t>(tmp);
274+
return bit_cast<result_t<OpTy>>(tmp);
229275
}
230276
#endif
231277
#endif
232-
result_t res{};
278+
result_t<OpTy> res{};
233279
for (size_t i = 0; i < N; ++i)
234-
if constexpr (is_logical)
280+
if constexpr (is_logical<OpTy>)
235281
res[i] = Op(Args[i]...) ? -1 : 0;
236282
else
237283
res[i] = Op(Args[i]...);
@@ -246,15 +292,57 @@ template <typename Self> struct VecOperators {
246292
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, IncDec>>>
247293
: public IncDecImpl<Self> {};
248294

295+
#define __SYCL_VEC_BINOP_MIXIN(OP, OPERATOR) \
296+
template <typename Op> \
297+
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
298+
template <typename T = element_type> \
299+
friend std::enable_if_t<is_op_available_for_type<OP, T>, result_t<OP>> \
300+
operator OPERATOR(const Self & Lhs, const Self & Rhs) { \
301+
return apply<OP>(Lhs, Rhs); \
302+
} \
303+
\
304+
template <typename T = element_type> \
305+
friend std::enable_if_t<is_op_available_for_type<OP, T>, result_t<OP>> \
306+
operator OPERATOR(const Self & Lhs, const element_type & Rhs) { \
307+
return OP{}(Lhs, Self{Rhs}); \
308+
} \
309+
template <typename T = element_type> \
310+
friend std::enable_if_t<is_op_available_for_type<OP, T>, result_t<OP>> \
311+
operator OPERATOR(const element_type & Lhs, const Self & Rhs) { \
312+
return OP{}(Self{Lhs}, Rhs); \
313+
} \
314+
};
315+
316+
#define __SYCL_VEC_OPASSIGN_MIXIN(OP, OPERATOR) \
317+
template <typename Op> \
318+
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OpAssign<OP>>>> { \
319+
template <typename T = element_type> \
320+
friend std::enable_if_t<is_op_available_for_type<OP, T>, Self> & \
321+
operator OPERATOR(Self & Lhs, const Self & Rhs) { \
322+
Lhs = OP{}(Lhs, Rhs); \
323+
return Lhs; \
324+
} \
325+
template <int Num = N, typename T = element_type> \
326+
friend std::enable_if_t<(Num != 1) && (is_op_available_for_type<OP, T>), \
327+
Self &> \
328+
operator OPERATOR(Self & Lhs, const element_type & Rhs) { \
329+
Lhs = OP{}(Lhs, Self{Rhs}); \
330+
return Lhs; \
331+
} \
332+
};
333+
249334
#define __SYCL_VEC_UOP_MIXIN(OP, OPERATOR) \
250335
template <typename Op> \
251336
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, OP>>> { \
252337
friend auto operator OPERATOR(const Self &v) { return apply<OP>(v); } \
253338
};
254339

255-
__SYCL_VEC_UOP_MIXIN(std::negate<void>, -)
256-
__SYCL_VEC_UOP_MIXIN(std::logical_not<void>, !)
257-
__SYCL_VEC_UOP_MIXIN(UnaryPlus, +)
340+
__SYCL_INSTANTIATE_OPERATORS(__SYCL_VEC_BINOP_MIXIN,
341+
__SYCL_VEC_OPASSIGN_MIXIN, __SYCL_VEC_UOP_MIXIN)
342+
343+
#undef __SYCL_VEC_UOP_MIXIN
344+
#undef __SYCL_VEC_OPASSIGN_MIXIN
345+
#undef __SYCL_VEC_BINOP_MIXIN
258346

259347
template <typename Op>
260348
struct OpMixin<Op, std::enable_if_t<std::is_same_v<Op, std::bit_not<void>>>> {
@@ -265,127 +353,34 @@ template <typename Self> struct VecOperators {
265353
}
266354
};
267355

268-
#undef __SYCL_VEC_UOP_MIXIN
269-
270356
template <typename... Op>
271357
struct __SYCL_EBO CombineImpl : public OpMixin<Op>... {};
272358

273359
struct Combined
274-
: public CombineImpl<std::negate<void>, std::logical_not<void>,
275-
std::bit_not<void>, UnaryPlus, IncDec> {};
360+
: CombineImpl<std::plus<void>, std::minus<void>, std::multiplies<void>,
361+
std::divides<void>, std::modulus<void>, std::bit_and<void>,
362+
std::bit_or<void>, std::bit_xor<void>, std::equal_to<void>,
363+
std::not_equal_to<void>, std::less<void>,
364+
std::greater<void>, std::less_equal<void>,
365+
std::greater_equal<void>, std::logical_and<void>,
366+
std::logical_or<void>, ShiftLeft, ShiftRight,
367+
std::negate<void>, std::logical_not<void>,
368+
std::bit_not<void>, UnaryPlus, OpAssign<std::plus<void>>,
369+
OpAssign<std::minus<void>>, OpAssign<std::multiplies<void>>,
370+
OpAssign<std::divides<void>>, OpAssign<std::modulus<void>>,
371+
OpAssign<std::bit_and<void>>, OpAssign<std::bit_or<void>>,
372+
OpAssign<std::bit_xor<void>>, OpAssign<ShiftLeft>,
373+
OpAssign<ShiftRight>, IncDec> {};
276374
};
277375

278-
// Macros to populate binary operation on sycl::vec.
279-
#if defined(__SYCL_BINOP)
280-
#error "Undefine __SYCL_BINOP macro"
281-
#endif
282-
283-
#define __SYCL_BINOP(BINOP, OPASSIGN, FUNCTOR) \
284-
template <typename T = DataT> \
285-
friend std::enable_if_t<is_op_available_for_type<FUNCTOR, T>, vec_t> \
286-
operator BINOP(const vec_t & Lhs, const vec_t & Rhs) { \
287-
return VecOperators<vec_t>::template apply<FUNCTOR>(Lhs, Rhs); \
288-
} \
289-
\
290-
template <typename T = DataT> \
291-
friend std::enable_if_t<is_op_available_for_type<FUNCTOR, T>, vec_t> \
292-
operator BINOP(const vec_t & Lhs, const DataT & Rhs) { \
293-
return Lhs BINOP vec_t(Rhs); \
294-
} \
295-
template <typename T = DataT> \
296-
friend std::enable_if_t<is_op_available_for_type<FUNCTOR, T>, vec_t> \
297-
operator BINOP(const DataT & Lhs, const vec_t & Rhs) { \
298-
return vec_t(Lhs) BINOP Rhs; \
299-
} \
300-
template <typename T = DataT> \
301-
friend std::enable_if_t<is_op_available_for_type<FUNCTOR, T>, vec_t> \
302-
&operator OPASSIGN(vec_t & Lhs, const vec_t & Rhs) { \
303-
Lhs = Lhs BINOP Rhs; \
304-
return Lhs; \
305-
} \
306-
template <int Num = NumElements, typename T = DataT> \
307-
friend std::enable_if_t< \
308-
(Num != 1) && (is_op_available_for_type<FUNCTOR, T>), vec_t &> \
309-
operator OPASSIGN(vec_t & Lhs, const DataT & Rhs) { \
310-
Lhs = Lhs BINOP vec_t(Rhs); \
311-
return Lhs; \
312-
}
313-
314376
template <typename DataT, int NumElements>
315-
class vec_arith : public VecOperators<vec<DataT, NumElements>>::Combined {
316-
protected:
317-
using vec_t = vec<DataT, NumElements>;
318-
using ocl_t = detail::fixed_width_signed<sizeof(DataT)>;
319-
320-
// The logical operations on scalar types results in 0/1, while for vec<>,
321-
// logical operations should result in 0 and -1 (similar to OpenCL vectors).
322-
// That's why, for vec<DataT, 1>, we need to invert the result of the logical
323-
// operations since we store vec<DataT, 1> as scalar type on the device.
324-
#if defined(__SYCL_RELLOGOP)
325-
#error "Undefine __SYCL_RELLOGOP macro."
326-
#endif
327-
328-
#define __SYCL_RELLOGOP(RELLOGOP, FUNCTOR) \
329-
template <typename T = DataT> \
330-
friend std::enable_if_t<is_op_available_for_type<FUNCTOR, T>, \
331-
vec<ocl_t, NumElements>> \
332-
operator RELLOGOP(const vec_t & Lhs, const vec_t & Rhs) { \
333-
return VecOperators<vec_t>::template apply<FUNCTOR>(Lhs, Rhs); \
334-
} \
335-
\
336-
template <typename T = DataT> \
337-
friend std::enable_if_t<is_op_available_for_type<FUNCTOR, T>, \
338-
vec<ocl_t, NumElements>> \
339-
operator RELLOGOP(const vec_t & Lhs, const DataT & Rhs) { \
340-
return Lhs RELLOGOP vec_t(Rhs); \
341-
} \
342-
template <typename T = DataT> \
343-
friend std::enable_if_t<is_op_available_for_type<FUNCTOR, T>, \
344-
vec<ocl_t, NumElements>> \
345-
operator RELLOGOP(const DataT & Lhs, const vec_t & Rhs) { \
346-
return vec_t(Lhs) RELLOGOP Rhs; \
347-
}
348-
349-
// OP is: ==, !=, <, >, <=, >=, &&, ||
350-
// vec<RET, NumElements> operatorOP(const vec<DataT, NumElements> &Rhs) const;
351-
// vec<RET, NumElements> operatorOP(const DataT &Rhs) const;
352-
__SYCL_RELLOGOP(==, std::equal_to<void>)
353-
__SYCL_RELLOGOP(!=, std::not_equal_to<void>)
354-
__SYCL_RELLOGOP(>, std::greater<void>)
355-
__SYCL_RELLOGOP(<, std::less<void>)
356-
__SYCL_RELLOGOP(>=, std::greater_equal<void>)
357-
__SYCL_RELLOGOP(<=, std::less_equal<void>)
358-
359-
// Only available to integral types.
360-
__SYCL_RELLOGOP(&&, std::logical_and<void>)
361-
__SYCL_RELLOGOP(||, std::logical_or<void>)
362-
#undef __SYCL_RELLOGOP
363-
#undef RELLOGOP_BASE
364-
365-
// Binary operations on sycl::vec<> for all types except std::byte.
366-
__SYCL_BINOP(+, +=, std::plus<void>)
367-
__SYCL_BINOP(-, -=, std::minus<void>)
368-
__SYCL_BINOP(*, *=, std::multiplies<void>)
369-
__SYCL_BINOP(/, /=, std::divides<void>)
370-
371-
// The following OPs are available only when: DataT != cl_float &&
372-
// DataT != cl_double && DataT != cl_half && DataT != BF16.
373-
__SYCL_BINOP(%, %=, std::modulus<void>)
374-
// Bitwise operations are allowed for std::byte.
375-
__SYCL_BINOP(|, |=, std::bit_or<void>)
376-
__SYCL_BINOP(&, &=, std::bit_and<void>)
377-
__SYCL_BINOP(^, ^=, std::bit_xor<void>)
378-
__SYCL_BINOP(>>, >>=, ShiftRight)
379-
__SYCL_BINOP(<<, <<=, ShiftLeft)
380-
381-
// friends
382-
template <typename T1, int T2> friend class __SYCL_EBO vec;
383-
}; // class vec_arith<>
377+
class vec_arith : public VecOperators<vec<DataT, NumElements>>::Combined {};
384378

385379
#if (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
386380
template <int NumElements>
387381
class vec_arith<std::byte, NumElements>
388-
: public VecOperators<vec<std::byte, NumElements>>::template OpMixin<
382+
: public VecOperators<vec<std::byte, NumElements>>::template CombineImpl<
383+
std::bit_or<void>, std::bit_and<void>, std::bit_xor<void>,
389384
std::bit_not<void>> {
390385
protected:
391386
// NumElements can never be zero. Still using the redundant check to avoid
@@ -426,17 +421,10 @@ class vec_arith<std::byte, NumElements>
426421
Lhs = Lhs >> shift;
427422
return Lhs;
428423
}
429-
430-
__SYCL_BINOP(|, |=, std::bit_or<void>)
431-
__SYCL_BINOP(&, &=, std::bit_and<void>)
432-
__SYCL_BINOP(^, ^=, std::bit_xor<void>)
433-
434-
// friends
435-
template <typename T1, int T2> friend class __SYCL_EBO vec;
436424
};
437425
#endif // (!defined(_HAS_STD_BYTE) || _HAS_STD_BYTE != 0)
438426

439-
#undef __SYCL_BINOP
427+
#undef __SYCL_INSTANTIATE_OPERATORS
440428

441429
} // namespace detail
442430
} // namespace _V1

0 commit comments

Comments
 (0)